diff --git a/.env b/.env
new file mode 100644
index 0000000..fc7ebc4
--- /dev/null
+++ b/.env
@@ -0,0 +1,7 @@
+POSTGRESQL_HOST=
+POSTGRESQL_PORT=
+POSTGRESQL_USERNAME=
+POSTGRESQL_PASSWORD=
+POSTGRESQL_DATABASE=
+
+VISUAL=0
\ No newline at end of file
diff --git a/.env.dev b/.env.dev
new file mode 100644
index 0000000..817e26a
--- /dev/null
+++ b/.env.dev
@@ -0,0 +1,7 @@
+POSTGRESQL_HOST=192.168.10.137
+POSTGRESQL_PORT=54321
+POSTGRESQL_USERNAME=postgres
+POSTGRESQL_PASSWORD=123456
+POSTGRESQL_DATABASE=pdf-qa
+
+VISUAL=1
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..fe192d2
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+venv
+*.pdf
+.vscode
+visual_images/*.jpg
+__pycache__
\ No newline at end of file
diff --git a/README.md b/README.md
index e69de29..85fb2e9 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,21 @@
+# Environment setup
+
+```
+# 1. Converting PDFs to images requires the following dependency
+apt install -y poppler-utils
+# 2. Work around the Paddle compatibility issue
+export FLAGS_enable_pir_api=0
+# 3. Environment configuration needed after installing paddlepaddle-gpu
+apt install libcudnn8
+apt install libcudnn8-dev
+echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/nvidia/cublas/lib/" >> /etc/profile
+# 4. Manually download the models MinerU needs; the download paths are:
+# model_dir is: /root/.cache/modelscope/hub/models/opendatalab/PDF-Extract-Kit-1___0/models
+# layoutreader_model_dir is: /root/.cache/modelscope/hub/models/ppaanngggg/layoutreader
+# The configuration file has been configured successfully, the path is: /root/magic-pdf.json
+# Link the magic-pdf.json in this project to /root/magic-pdf.json
+ln -s `pwd`/magic-pdf.json /root/magic-pdf.json
+python download_MinerU_models.py
+# 5. Dependencies required for Python to connect to PostgreSQL
+apt install postgresql postgresql-contrib libpq-dev
+```
diff --git a/download_MinerU_models.py b/download_MinerU_models.py
new file mode 100644
index 0000000..626473d
--- /dev/null
+++ b/download_MinerU_models.py
@@ -0,0 +1,67 @@
+import json
+import shutil
+import os
+
+import requests
+from modelscope import snapshot_download
+
+
+def download_json(url):
+    # Download the JSON file
+    response = requests.get(url)
+    response.raise_for_status()  # Check that the request succeeded
+    return response.json()
+
+
+def download_and_modify_json(url, local_filename, modifications):
+    if os.path.exists(local_filename):
+        data = json.load(open(local_filename))
+        config_version = data.get('config_version', '0.0.0')
+        if config_version < '1.2.0':
+            data = download_json(url)
+    else:
+        data = download_json(url)
+
+    # Apply the requested modifications
+    for key, value in modifications.items():
+        data[key] = value
+
+    # Save the modified content
+    with open(local_filename, 'w', encoding='utf-8') as f:
+        json.dump(data, f, ensure_ascii=False, indent=4)
+
+
+if __name__ == '__main__':
+    mineru_patterns = [
+        # "models/Layout/LayoutLMv3/*",
+        "models/Layout/YOLO/*",
+        "models/MFD/YOLO/*",
+        "models/MFR/unimernet_hf_small_2503/*",
+        "models/OCR/paddleocr_torch/*",
+        # "models/TabRec/TableMaster/*",
+        # "models/TabRec/StructEqTable/*",
+    ]
+    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
+    layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
+    model_dir = model_dir + '/models'
+    print(f'model_dir is: {model_dir}')
+    print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
+
+    # paddleocr_model_dir = model_dir + '/OCR/paddleocr'
+    # user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
+    # if os.path.exists(user_paddleocr_dir):
+    #     shutil.rmtree(user_paddleocr_dir)
+    #     shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)
+
+    json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json'
+    config_file_name = 'magic-pdf.json'
+    home_dir = os.path.expanduser('~')
+    config_file = os.path.join(home_dir, config_file_name)
+
+    json_mods = {
+        'models-dir': model_dir,
+        'layoutreader-model-dir': layoutreader_model_dir,
+    }
+
+    download_and_modify_json(json_url, config_file, json_mods)
+    print(f'The configuration file has been configured successfully, the path is: {config_file}')
diff --git a/helper/content_recognition/main.py b/helper/content_recognition/main.py
new file mode 100644
index 0000000..b51f2c9
--- /dev/null
+++ b/helper/content_recognition/main.py
@@ -0,0 +1,116 @@
+from typing import List
+import cv2
+from .utils import scanning_document_classify, text_rec, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
+from tqdm import tqdm
+
+
+class LayoutRecognitionResult(object):
+
+    def __init__(self, clsid, content, box, table_title=None):
+        self.clsid = clsid
+        self.content = content
+        self.box = box
+        self.table_title = table_title
+
+    def __repr__(self):
+        return f"[{self.clsid}] {self.content}"
+
+
+expand_pixel = 10
+
+
+def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
+    page_recognition_results = []
+
+    for page_idx in tqdm(range(len(page_detection_results)), '文本识别'):
+        results = page_detection_results[page_idx]
+        if not results.boxes:
+            page_recognition_results.append([])
+            continue
+
+        img = cv2.imread(results.image_path)
+        h, w = img.shape[:2]
+
+        for layout in results.boxes:
+            # Expand the box a little outward to make OCR easier
+            layout.pos[0] -= expand_pixel
+            layout.pos[1] -= expand_pixel
+            layout.pos[2] += expand_pixel
+            layout.pos[3] += expand_pixel
+
+            layout.pos[0] = max(0, layout.pos[0])
+            layout.pos[1] = max(0, layout.pos[1])
+            layout.pos[2] = min(w, layout.pos[2])
+            layout.pos[3] = min(h, layout.pos[3])
+
+        outputs = []
+
+        is_scanning_document = False
+        for layout in results.boxes:
+            x1, y1, x2, y2 = layout.pos
+            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+            layout_img = img[y1: y2, x1: x2]
+            content = None
+            if layout.clsid == 0:
+                # text
+                content = markdown_rec(layout_img)
+            elif layout.clsid == 2:
+                # figure
+                if scanning_document_classify(layout_img):
+                    # scanned document
+                    is_scanning_document = True
+                    content, layout_img = scanning_document_rec(layout_img)
+                    source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
+            elif layout.clsid == 4:
+                # table
+                if scanning_document_classify(layout_img):
+                    is_scanning_document = True
+                    content, layout_img = scanning_document_rec(layout_img)
+                    source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
+                else:
+                    content = table_rec(layout_img)
+            elif layout.clsid == 5:
+                # table caption
+                ocr_results = text_rec(layout_img)
+                content = ''
+                for o in ocr_results:
+                    content += f'{o}\n'
+                while content.endswith('\n'):
+                    content = content[:-1]
+
+            if not content:
+                continue
+
+            result = LayoutRecognitionResult(layout.clsid, content, layout.pos)
+            outputs.append(result)
+
+        if is_scanning_document and len(outputs) == 1:
+            # For scanned documents, additionally extract the title
+            h, w = source_page_no_watermark_img.shape[:2]
+            if h > w:
+                title_img = source_page_no_watermark_img[:360, :w, ...]
+
+                # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
+                # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
+                # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
+            else:
+                title_img = source_page_no_watermark_img[:410, :w, ...]
+
+                # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
+                # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
+                # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
+            title = text_rec(title_img)
+            outputs[0].table_title = '\n'.join(title)
+        else:
+            # Automatically assign each table to the title closest to it
+            assign_tables_to_titles(outputs)
+
+        # The table captions can be dropped now
+        outputs = [_ for _ in outputs if _.clsid != 5]
+        # Map 2 (figure) and 4 (table) to the database enum value 1 (table)
+        for o in outputs:
+            if o.clsid == 2 or o.clsid == 4:
+                o.clsid = 1
+        page_recognition_results.append(outputs)
+
+    return page_recognition_results
diff --git a/helper/content_recognition/rapid_table_pipeline/__init__.py b/helper/content_recognition/rapid_table_pipeline/__init__.py
new file mode 100644
index 0000000..702f7d0
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/__init__.py
@@ -0,0 +1,5 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from .main import RapidTable, RapidTableInput
+from .utils import VisTable
diff --git a/helper/content_recognition/rapid_table_pipeline/main.py b/helper/content_recognition/rapid_table_pipeline/main.py
new file mode 100644
index 0000000..3dcfef4
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/main.py
@@ -0,0 +1,258 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import argparse
+import copy
+import importlib
+import time
+from dataclasses import asdict, dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+
+from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable
+
+from .table_matcher import TableMatch
+from .table_structure import TableStructurer, TableStructureUnitable
+from markdownify import markdownify as md
+
+logger = Logger(logger_name=__name__).get_log()
+root_dir = Path(__file__).resolve().parent
+
+
+class ModelType(Enum):
+    PPSTRUCTURE_EN = "ppstructure_en"
+    PPSTRUCTURE_ZH = "ppstructure_zh"
+    SLANETPLUS = "slanet_plus"
+    UNITABLE = "unitable"
+
+
+ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
+KEY_TO_MODEL_URL = {
+    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
+    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
+    ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
+    ModelType.UNITABLE.value: {
+        "encoder": f"{ROOT_URL}/unitable/encoder.pth",
+        "decoder": f"{ROOT_URL}/unitable/decoder.pth",
+        "vocab": f"{ROOT_URL}/unitable/vocab.json",
+    },
+}
+
+
+@dataclass
+class RapidTableInput:
+    model_type: Optional[str] = ModelType.SLANETPLUS.value
+    model_path: Union[str, Path, None, Dict[str, str]] = None
+    use_cuda: bool = False
+    device: str = "cpu"
+
+
+@dataclass
+class RapidTableOutput:
+    pred_html: Optional[str] = None
+    cell_bboxes: Optional[np.ndarray] = None
+    logic_points: Optional[np.ndarray] = None
+    elapse: Optional[float] = None
+
+
+class RapidTable:
+    def __init__(self, config: RapidTableInput):
+        self.model_type = config.model_type
+        if self.model_type not in KEY_TO_MODEL_URL:
+            model_list = ",".join(KEY_TO_MODEL_URL)
+            raise
ValueError( + f"{self.model_type} is not supported. The currently supported models are {model_list}." + ) + + config.model_path = self.get_model_path(config.model_type, config.model_path) + if self.model_type == ModelType.UNITABLE.value: + self.table_structure = TableStructureUnitable(asdict(config)) + else: + self.table_structure = TableStructurer(asdict(config)) + + self.table_matcher = TableMatch() + + try: + self.ocr_engine = importlib.import_module("rapidocr").RapidOCR() + except ModuleNotFoundError: + self.ocr_engine = None + + self.load_img = LoadImage() + + def __call__( + self, + img_content: Union[str, np.ndarray, bytes, Path], + ocr_result: List[Union[List[List[float]], str, str]] = None, + ) -> RapidTableOutput: + if self.ocr_engine is None and ocr_result is None: + raise ValueError( + "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed." + ) + + img = self.load_img(img_content) + + s = time.perf_counter() + h, w = img.shape[:2] + + if ocr_result is None: + ocr_result = self.ocr_engine(img) + ocr_result = list( + zip( + ocr_result.boxes, + ocr_result.txts, + ocr_result.scores, + ) + ) + dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w) + + pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img)) + + # 适配slanet-plus模型输出的box缩放还原 + if self.model_type == ModelType.SLANETPLUS.value: + cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes) + + pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res) + + # 过滤掉占位的bbox + mask = ~np.all(cell_bboxes == 0, axis=1) + cell_bboxes = cell_bboxes[mask] + + logic_points = self.table_matcher.decode_logic_points(pred_structures) + elapse = time.perf_counter() - s + return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse) + + def get_boxes_recs( + self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int + ) -> Tuple[np.ndarray, Tuple[str, str]]: + dt_boxes, rec_res, scores = list(zip(*ocr_result)) + rec_res = list(zip(rec_res, scores)) + + r_boxes = [] + for box in dt_boxes: + box = np.array(box) + x_min = max(0, box[:, 0].min() - 1) + x_max = min(w, box[:, 0].max() + 1) + y_min = max(0, box[:, 1].min() - 1) + y_max = min(h, box[:, 1].max() + 1) + box = [x_min, y_min, x_max, y_max] + r_boxes.append(box) + dt_boxes = np.array(r_boxes) + return dt_boxes, rec_res + + def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray: + h, w = img.shape[:2] + resized = 488 + ratio = min(resized / h, resized / w) + w_ratio = resized / (w * ratio) + h_ratio = resized / (h * ratio) + cell_bboxes[:, 0::2] *= w_ratio + cell_bboxes[:, 1::2] *= h_ratio + return cell_bboxes + + @staticmethod + def get_model_path( + model_type: str, model_path: Union[str, Path, None] + ) -> Union[str, Dict[str, str]]: + if model_path is not None: + return model_path + + model_url = KEY_TO_MODEL_URL.get(model_type, None) + if isinstance(model_url, str): + model_path = DownloadModel.download(model_url) + return model_path + + if isinstance(model_url, dict): + model_paths = {} + for k, url in model_url.items(): + model_paths[k] = DownloadModel.download( + url, save_model_name=f"{model_type}_{Path(url).name}" + ) + return model_paths + + raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.") + + +def parse_args(arg_list: Optional[List[str]] = None): + parser = argparse.ArgumentParser() + parser.add_argument( + "-v", + "--vis", + action="store_true", + default=False, + help="Wheter to visualize the layout results.", + ) + 
parser.add_argument( + "-img", "--img_path", type=str, required=True, help="Path to image for layout." + ) + parser.add_argument( + "-m", + "--model_type", + type=str, + default=ModelType.SLANETPLUS.value, + choices=list(KEY_TO_MODEL_URL), + ) + args = parser.parse_args(arg_list) + return args + + +try: + ocr_engine = importlib.import_module("rapidocr").RapidOCR() +except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "Please install the rapidocr by pip install rapidocr" + ) from exc + +input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value) +table_engine = RapidTable(input_args) + +def table2md_pipeline(img): + rapid_ocr_output = ocr_engine(img) + ocr_result = list( + zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores) + ) + table_results = table_engine(img, ocr_result) + html_content = table_results.pred_html + md_content = md(html_content) + return md_content + + +# def main(arg_list: Optional[List[str]] = None): +# args = parse_args(arg_list) + +# try: +# ocr_engine = importlib.import_module("rapidocr").RapidOCR() +# except ModuleNotFoundError as exc: +# raise ModuleNotFoundError( +# "Please install the rapidocr by pip install rapidocr" +# ) from exc + +# input_args = RapidTableInput(model_type=args.model_type) +# table_engine = RapidTable(input_args) + +# img = cv2.imread(args.img_path) + +# rapid_ocr_output = ocr_engine(img) +# ocr_result = list( +# zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores) +# ) +# table_results = table_engine(img, ocr_result) +# print(table_results.pred_html) + +# viser = VisTable() +# if args.vis: +# img_path = Path(args.img_path) + +# save_dir = img_path.resolve().parent +# save_html_path = save_dir / f"{Path(img_path).stem}.html" +# save_drawed_path = save_dir / f"vis_{Path(img_path).name}" +# viser(img_path, table_results, save_html_path, save_drawed_path) + + +if __name__ == "__main__": + res = table2md_pipeline(cv2.imread('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/11.jpg')) + print('*' * 50) + print(res) diff --git a/helper/content_recognition/rapid_table_pipeline/models/ch_ppstructure_mobile_v2_SLANet.onnx b/helper/content_recognition/rapid_table_pipeline/models/ch_ppstructure_mobile_v2_SLANet.onnx new file mode 100644 index 0000000..d8b58fc Binary files /dev/null and b/helper/content_recognition/rapid_table_pipeline/models/ch_ppstructure_mobile_v2_SLANet.onnx differ diff --git a/helper/content_recognition/rapid_table_pipeline/models/slanet-plus.onnx b/helper/content_recognition/rapid_table_pipeline/models/slanet-plus.onnx new file mode 100644 index 0000000..d263823 Binary files /dev/null and b/helper/content_recognition/rapid_table_pipeline/models/slanet-plus.onnx differ diff --git a/helper/content_recognition/rapid_table_pipeline/table_matcher/__init__.py b/helper/content_recognition/rapid_table_pipeline/table_matcher/__init__.py new file mode 100644 index 0000000..9bff7d7 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/table_matcher/__init__.py @@ -0,0 +1,4 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from .matcher import TableMatch diff --git a/helper/content_recognition/rapid_table_pipeline/table_matcher/matcher.py b/helper/content_recognition/rapid_table_pipeline/table_matcher/matcher.py new file mode 100644 index 0000000..f579a12 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/table_matcher/matcher.py @@ -0,0 +1,199 @@ +# copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -*- encoding: utf-8 -*- +import numpy as np + +from .utils import compute_iou, distance + + +class TableMatch: + def __init__(self, filter_ocr_result=True, use_master=False): + self.filter_ocr_result = filter_ocr_result + self.use_master = use_master + + def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res): + if self.filter_ocr_result: + dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res) + matched_index = self.match_result(dt_boxes, cell_bboxes) + pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) + return pred_html + + def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8): + matched = {} + for i, gt_box in enumerate(dt_boxes): + distances = [] + for j, pred_box in enumerate(cell_bboxes): + if len(pred_box) == 8: + pred_box = [ + np.min(pred_box[0::2]), + np.min(pred_box[1::2]), + np.max(pred_box[0::2]), + np.max(pred_box[1::2]), + ] + distances.append( + (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box)) + ) # compute iou and l1 distance + sorted_distances = distances.copy() + # select det box by iou and l1 distance + sorted_distances = sorted( + sorted_distances, key=lambda item: (item[1], item[0]) + ) + # must > min_iou + if sorted_distances[0][1] >= 1 - min_iou: + continue + + if distances.index(sorted_distances[0]) not in matched: + matched[distances.index(sorted_distances[0])] = [i] + else: + matched[distances.index(sorted_distances[0])].append(i) + return matched + + def get_pred_html(self, pred_structures, matched_index, ocr_contents): + end_html = [] + td_index = 0 + for tag in pred_structures: + if "" not in tag: + end_html.append(tag) + continue + + if "" == tag: + end_html.extend("") + + if td_index in matched_index.keys(): + b_with = False + if ( + "" in ocr_contents[matched_index[td_index][0]] + and len(matched_index[td_index]) > 1 + ): + b_with = True + end_html.extend("") + + for i, td_index_index in enumerate(matched_index[td_index]): + content = ocr_contents[td_index_index][0] + if len(matched_index[td_index]) > 1: + if len(content) == 0: + continue + + if content[0] == " ": + content = content[1:] + + if "" in content: + content = content[3:] + + if "" in content: + content = content[:-4] + + if len(content) == 0: + continue + + if i != len(matched_index[td_index]) - 1 and " " != content[-1]: + content += " " + end_html.extend(content) + + if b_with: + end_html.extend("") + + if "" == tag: + end_html.append("") + else: + end_html.append(tag) + + td_index += 1 + + # Filter elements + filter_elements = ["", "", "", ""] + end_html = [v for v in end_html if v not in filter_elements] + return "".join(end_html), end_html + + def decode_logic_points(self, pred_structures): + logic_points = [] + current_row = 0 + current_col = 0 + max_rows = 0 + max_cols = 0 + occupied_cells = {} # 用于记录已经被占用的单元格 + + def is_occupied(row, col): + return (row, col) in occupied_cells + + def mark_occupied(row, col, rowspan, colspan): + for r in 
range(row, row + rowspan): + for c in range(col, col + colspan): + occupied_cells[(r, c)] = True + + i = 0 + while i < len(pred_structures): + token = pred_structures[i] + + if token == "": + current_col = 0 # 每次遇到 时,重置当前列号 + elif token == "": + current_row += 1 # 行结束,行号增加 + elif token.startswith(""): + if "colspan=" in pred_structures[j]: + colspan = int(pred_structures[j].split("=")[1].strip("\"'")) + elif "rowspan=" in pred_structures[j]: + rowspan = int(pred_structures[j].split("=")[1].strip("\"'")) + j += 1 + + # 跳过已经处理过的属性 token + i = j + + # 找到下一个未被占用的列 + while is_occupied(current_row, current_col): + current_col += 1 + + # 计算逻辑坐标 + r_start = current_row + r_end = current_row + rowspan - 1 + col_start = current_col + col_end = current_col + colspan - 1 + + # 记录逻辑坐标 + logic_points.append([r_start, r_end, col_start, col_end]) + + # 标记占用的单元格 + mark_occupied(r_start, col_start, rowspan, colspan) + + # 更新当前列号 + current_col += colspan + + # 更新最大行数和列数 + max_rows = max(max_rows, r_end + 1) + max_cols = max(max_cols, col_end + 1) + + i += 1 + + return logic_points + + def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res): + y1 = cell_bboxes[:, 1::2].min() + new_dt_boxes = [] + new_rec_res = [] + + for box, rec in zip(dt_boxes, rec_res): + if np.max(box[1::2]) < y1: + continue + new_dt_boxes.append(box) + new_rec_res.append(rec) + return new_dt_boxes, new_rec_res diff --git a/helper/content_recognition/rapid_table_pipeline/table_matcher/utils.py b/helper/content_recognition/rapid_table_pipeline/table_matcher/utils.py new file mode 100644 index 0000000..57a613c --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/table_matcher/utils.py @@ -0,0 +1,249 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import copy +import re + + +def deal_isolate_span(thead_part): + """ + Deal with isolate span cases in this function. + It causes by wrong prediction in structure recognition model. + eg. predict to rowspan="2">. + :param thead_part: + :return: + """ + # 1. find out isolate span tokens. + isolate_pattern = ( + ' rowspan="(\d)+" colspan="(\d)+">|' + ' colspan="(\d)+" rowspan="(\d)+">|' + ' rowspan="(\d)+">|' + ' colspan="(\d)+">' + ) + isolate_iter = re.finditer(isolate_pattern, thead_part) + isolate_list = [i.group() for i in isolate_iter] + + # 2. find out span number, by step 1 results. + span_pattern = ( + ' rowspan="(\d)+" colspan="(\d)+"|' + ' colspan="(\d)+" rowspan="(\d)+"|' + ' rowspan="(\d)+"|' + ' colspan="(\d)+"' + ) + corrected_list = [] + for isolate_item in isolate_list: + span_part = re.search(span_pattern, isolate_item) + spanStr_in_isolateItem = span_part.group() + # 3. merge the span number into the span token format string. + if spanStr_in_isolateItem is not None: + corrected_item = f"" + corrected_list.append(corrected_item) + else: + corrected_list.append(None) + + # 4. replace original isolated token. 
+ for corrected_item, isolate_item in zip(corrected_list, isolate_list): + if corrected_item is not None: + thead_part = thead_part.replace(isolate_item, corrected_item) + else: + pass + return thead_part + + +def deal_duplicate_bb(thead_part): + """ + Deal duplicate or after replace. + Keep one in a token. + :param thead_part: + :return: + """ + # 1. find out in . + td_pattern = ( + '(.+?)|' + '(.+?)|' + '(.+?)|' + '(.+?)|' + "(.*?)" + ) + td_iter = re.finditer(td_pattern, thead_part) + td_list = [t.group() for t in td_iter] + + # 2. is multiply in or not? + new_td_list = [] + for td_item in td_list: + if td_item.count("") > 1 or td_item.count("") > 1: + # multiply in case. + # 1. remove all + td_item = td_item.replace("", "").replace("", "") + # 2. replace -> , -> . + td_item = td_item.replace("", "").replace("", "") + new_td_list.append(td_item) + else: + new_td_list.append(td_item) + + # 3. replace original thead part. + for td_item, new_td_item in zip(td_list, new_td_list): + thead_part = thead_part.replace(td_item, new_td_item) + return thead_part + + +def deal_bb(result_token): + """ + In our opinion, always occurs in text's context. + This function will find out all tokens in and insert by manual. + :param result_token: + :return: + """ + # find out parts. + thead_pattern = "(.*?)" + if re.search(thead_pattern, result_token) is None: + return result_token + thead_part = re.search(thead_pattern, result_token).group() + origin_thead_part = copy.deepcopy(thead_part) + + # check "rowspan" or "colspan" occur in parts or not . + span_pattern = '|||' + span_iter = re.finditer(span_pattern, thead_part) + span_list = [s.group() for s in span_iter] + has_span_in_head = True if len(span_list) > 0 else False + + if not has_span_in_head: + # not include "rowspan" or "colspan" branch 1. + # 1. replace to , and to + # 2. it is possible to predict text include or by Text-line recognition, + # so we replace to , and to + thead_part = ( + thead_part.replace("", "") + .replace("", "") + .replace("", "") + .replace("", "") + ) + else: + # include "rowspan" or "colspan" branch 2. + # Firstly, we deal rowspan or colspan cases. + # 1. replace > to > + # 2. replace to + # 3. it is possible to predict text include or by Text-line recognition, + # so we replace to , and to + + # Secondly, deal ordinary cases like branch 1 + + # replace ">" to "" + replaced_span_list = [] + for sp in span_list: + replaced_span_list.append(sp.replace(">", ">")) + for sp, rsp in zip(span_list, replaced_span_list): + thead_part = thead_part.replace(sp, rsp) + + # replace "" to "" + thead_part = thead_part.replace("", "") + + # remove duplicated by re.sub + mb_pattern = "()+" + single_b_string = "" + thead_part = re.sub(mb_pattern, single_b_string, thead_part) + + mgb_pattern = "()+" + single_gb_string = "" + thead_part = re.sub(mgb_pattern, single_gb_string, thead_part) + + # ordinary cases like branch 1 + thead_part = thead_part.replace("", "").replace("", "") + + # convert back to , empty cell has no . + # but space cell( ) is suitable for + thead_part = thead_part.replace("", "") + # deal with duplicated + thead_part = deal_duplicate_bb(thead_part) + # deal with isolate span tokens, which causes by wrong predict by structure prediction. + # eg.PMC5994107_011_00.png + thead_part = deal_isolate_span(thead_part) + # replace original result with new thead part. + result_token = result_token.replace(origin_thead_part, thead_part) + return result_token + + +def deal_eb_token(master_token): + """ + post process with , , ... 
+ emptyBboxTokenDict = { + "[]": '', + "[' ']": '', + "['', ' ', '']": '', + "['\\u2028', '\\u2028']": '', + "['', ' ', '']": '', + "['', '']": '', + "['', ' ', '']": '', + "['', '', '', '']": '', + "['', '', ' ', '', '']": '', + "['', '']": '', + "['', ' ', '\\u2028', ' ', '\\u2028', ' ', '']": '', + } + :param master_token: + :return: + """ + master_token = master_token.replace("", "") + master_token = master_token.replace("", " ") + master_token = master_token.replace("", " ") + master_token = master_token.replace("", "\u2028\u2028") + master_token = master_token.replace("", " ") + master_token = master_token.replace("", "") + master_token = master_token.replace("", " ") + master_token = master_token.replace("", "") + master_token = master_token.replace("", " ") + master_token = master_token.replace("", "") + master_token = master_token.replace( + "", " \u2028 \u2028 " + ) + return master_token + + +def distance(box_1, box_2): + x1, y1, x2, y2 = box_1 + x3, y3, x4, y4 = box_2 + dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2) + dis_2 = abs(x3 - x1) + abs(y3 - y1) + dis_3 = abs(x4 - x2) + abs(y4 - y2) + return dis + min(dis_2, dis_3) + + +def compute_iou(rec1, rec2): + """ + computing IoU + :param rec1: (y0, x0, y1, x1), which reflects + (top, left, bottom, right) + :param rec2: (y0, x0, y1, x1) + :return: scala value of IoU + """ + # computing area of each rectangles + S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) + S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) + + # computing the sum_area + sum_area = S_rec1 + S_rec2 + + # find the each edge of intersect rectangle + left_line = max(rec1[1], rec2[1]) + right_line = min(rec1[3], rec2[3]) + top_line = max(rec1[0], rec2[0]) + bottom_line = min(rec1[2], rec2[2]) + + # judge if there is an intersect + if left_line >= right_line or top_line >= bottom_line: + return 0.0 + + intersect = (right_line - left_line) * (bottom_line - top_line) + return (intersect / (sum_area - intersect)) * 1.0 diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/__init__.py b/helper/content_recognition/rapid_table_pipeline/table_structure/__init__.py new file mode 100644 index 0000000..3e638d9 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/table_structure/__init__.py @@ -0,0 +1,15 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .table_structure import TableStructurer +from .table_structure_unitable import TableStructureUnitable diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure.py b/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure.py new file mode 100644 index 0000000..9152603 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure.py @@ -0,0 +1,58 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from typing import Any, Dict + +import numpy as np + +from .utils import OrtInferSession, TableLabelDecode, TablePreprocess + + +class TableStructurer: + def __init__(self, config: Dict[str, Any]): + self.preprocess_op = TablePreprocess() + + self.session = OrtInferSession(config) + + self.character = self.session.get_metadata() + self.postprocess_op = TableLabelDecode(self.character) + + def __call__(self, img): + starttime = time.time() + data = {"image": img} + data = self.preprocess_op(data) + img = data[0] + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + img = img.copy() + + outputs = self.session([img]) + + preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]} + + shape_list = np.expand_dims(data[-1], axis=0) + post_result = self.postprocess_op(preds, [shape_list]) + + bbox_list = post_result["bbox_batch_list"][0] + + structure_str_list = post_result["structure_batch_list"][0] + structure_str_list = structure_str_list[0] + structure_str_list = ( + ["", "", ""] + + structure_str_list + + ["
", "", ""] + ) + elapse = time.time() - starttime + return structure_str_list, bbox_list, elapse diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure_unitable.py b/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure_unitable.py new file mode 100644 index 0000000..2b98006 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure_unitable.py @@ -0,0 +1,229 @@ +import re +import time + +import cv2 +import numpy as np +import torch +from PIL import Image +from tokenizers import Tokenizer +from torchvision import transforms + +from .unitable_modules import Encoder, GPTFastDecoder + +IMG_SIZE = 448 +EOS_TOKEN = "" +BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)] + +HTML_BBOX_HTML_TOKENS = [ + "", + "[", + "]", + "[", + ">", + "", + "", + "", + "", + "", + "", + ' rowspan="2"', + ' rowspan="3"', + ' rowspan="4"', + ' rowspan="5"', + ' rowspan="6"', + ' rowspan="7"', + ' rowspan="8"', + ' rowspan="9"', + ' rowspan="10"', + ' rowspan="11"', + ' rowspan="12"', + ' rowspan="13"', + ' rowspan="14"', + ' rowspan="15"', + ' rowspan="16"', + ' rowspan="17"', + ' rowspan="18"', + ' rowspan="19"', + ' colspan="2"', + ' colspan="3"', + ' colspan="4"', + ' colspan="5"', + ' colspan="6"', + ' colspan="7"', + ' colspan="8"', + ' colspan="9"', + ' colspan="10"', + ' colspan="11"', + ' colspan="12"', + ' colspan="13"', + ' colspan="14"', + ' colspan="15"', + ' colspan="16"', + ' colspan="17"', + ' colspan="18"', + ' colspan="19"', + ' colspan="25"', +] + +VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS +TASK_TOKENS = [ + "[table]", + "[html]", + "[cell]", + "[bbox]", + "[cell+bbox]", + "[html+bbox]", +] + + +class TableStructureUnitable: + def __init__(self, config): + # encoder_path: str, decoder_path: str, vocab_path: str, device: str + vocab_path = config["model_path"]["vocab"] + encoder_path = config["model_path"]["encoder"] + decoder_path = config["model_path"]["decoder"] + device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu" + + self.vocab = Tokenizer.from_file(vocab_path) + self.token_white_list = [ + self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS + ] + self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS) + self.bbox_close_html_token = self.vocab.token_to_id("]") + self.prefix_token_id = self.vocab.token_to_id("[html+bbox]") + self.eos_id = self.vocab.token_to_id(EOS_TOKEN) + self.max_seq_len = 1024 + self.device = device + self.img_size = IMG_SIZE + + # init encoder + encoder_state_dict = torch.load(encoder_path, map_location=device) + self.encoder = Encoder() + self.encoder.load_state_dict(encoder_state_dict) + self.encoder.eval().to(device) + + # init decoder + decoder_state_dict = torch.load(decoder_path, map_location=device) + self.decoder = GPTFastDecoder() + self.decoder.load_state_dict(decoder_state_dict) + self.decoder.eval().to(device) + + # define img transform + self.transform = transforms.Compose( + [ + transforms.Resize((448, 448)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.86597056, 0.88463002, 0.87491087], + std=[0.20686628, 0.18201602, 0.18485524], + ), + ] + ) + + @torch.inference_mode() + def __call__(self, image: np.ndarray): + start_time = time.time() + ori_h, ori_w = image.shape[:2] + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = Image.fromarray(image) + image = self.transform(image).unsqueeze(0).to(self.device) + self.decoder.setup_caches( + max_batch_size=1, + 
max_seq_length=self.max_seq_len, + dtype=image.dtype, + device=self.device, + ) + context = ( + torch.tensor([self.prefix_token_id], dtype=torch.int32) + .repeat(1, 1) + .to(self.device) + ) + eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(self.device) + memory = self.encoder(image) + context = self.loop_decode(context, eos_id_tensor, memory) + bboxes, html_tokens = self.decode_tokens(context) + bboxes = bboxes.astype(np.float32) + + # rescale boxes + scale_h = ori_h / self.img_size + scale_w = ori_w / self.img_size + bboxes[:, 0::2] *= scale_w # 缩放 x 坐标 + bboxes[:, 1::2] *= scale_h # 缩放 y 坐标 + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1) + structure_str_list = ( + ["", "", ""] + + html_tokens + + ["
", "", ""] + ) + return structure_str_list, bboxes, time.time() - start_time + + def decode_tokens(self, context): + pred_html = context[0] + pred_html = pred_html.detach().cpu().numpy() + pred_html = self.vocab.decode(pred_html, skip_special_tokens=False) + seq = pred_html.split("")[0] + token_black_list = ["", "", *TASK_TOKENS] + for i in token_black_list: + seq = seq.replace(i, "") + + tr_pattern = re.compile(r"(.*?)", re.DOTALL) + td_pattern = re.compile(r"(.*?)", re.DOTALL) + bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]") + + decoded_list = [] + bbox_coords = [] + + # 查找所有的 标签 + for tr_match in tr_pattern.finditer(pred_html): + tr_content = tr_match.group(1) + decoded_list.append("") + + # 查找所有的 标签 + for td_match in td_pattern.finditer(tr_content): + td_attrs = td_match.group(1).strip() + td_content = td_match.group(2).strip() + if td_attrs: + decoded_list.append("") + decoded_list.append("") + else: + decoded_list.append("") + + # 查找 bbox 坐标 + bbox_match = bbox_pattern.search(td_content) + if bbox_match: + xmin, ymin, xmax, ymax = map(int, bbox_match.groups()) + # 将坐标转换为从左上角开始顺时针到左下角的点的坐标 + coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]) + bbox_coords.append(coords) + else: + # 填充占位的bbox,保证后续流程统一 + bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0])) + decoded_list.append("") + + bbox_coords_array = np.array(bbox_coords) + return bbox_coords_array, decoded_list + + def loop_decode(self, context, eos_id_tensor, memory): + box_token_count = 0 + for _ in range(self.max_seq_len): + eos_flag = (context == eos_id_tensor).any(dim=1) + if torch.all(eos_flag): + break + + next_tokens = self.decoder(memory, context) + if next_tokens[0] in self.bbox_token_ids: + box_token_count += 1 + if box_token_count > 4: + next_tokens = torch.tensor( + [self.bbox_close_html_token], dtype=torch.int32 + ) + box_token_count = 0 + context = torch.cat([context, next_tokens], dim=1) + return context diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/unitable_modules.py b/helper/content_recognition/rapid_table_pipeline/table_structure/unitable_modules.py new file mode 100644 index 0000000..5b8dac3 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/table_structure/unitable_modules.py @@ -0,0 +1,911 @@ +from dataclasses import dataclass +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +from torch.nn.modules.transformer import _get_activation_fn + +TOKEN_WHITE_LIST = [ + 1, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 153, + 154, + 155, + 156, + 157, + 158, + 159, + 
160, + 161, + 162, + 163, + 164, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 173, + 174, + 175, + 176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 184, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 194, + 195, + 196, + 197, + 198, + 199, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 211, + 212, + 213, + 214, + 215, + 216, + 217, + 218, + 219, + 220, + 221, + 222, + 223, + 224, + 225, + 226, + 227, + 228, + 229, + 230, + 231, + 232, + 233, + 234, + 235, + 236, + 237, + 238, + 239, + 240, + 241, + 242, + 243, + 244, + 245, + 246, + 247, + 248, + 249, + 250, + 251, + 252, + 253, + 254, + 255, + 256, + 257, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265, + 266, + 267, + 268, + 269, + 270, + 271, + 272, + 273, + 274, + 275, + 276, + 277, + 278, + 279, + 280, + 281, + 282, + 283, + 284, + 285, + 286, + 287, + 288, + 289, + 290, + 291, + 292, + 293, + 294, + 295, + 296, + 297, + 298, + 299, + 300, + 301, + 302, + 303, + 304, + 305, + 306, + 307, + 308, + 309, + 310, + 311, + 312, + 313, + 314, + 315, + 316, + 317, + 318, + 319, + 320, + 321, + 322, + 323, + 324, + 325, + 326, + 327, + 328, + 329, + 330, + 331, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 341, + 342, + 343, + 344, + 345, + 346, + 347, + 348, + 349, + 350, + 351, + 352, + 353, + 354, + 355, + 356, + 357, + 358, + 359, + 360, + 361, + 362, + 363, + 364, + 365, + 366, + 367, + 368, + 369, + 370, + 371, + 372, + 373, + 374, + 375, + 376, + 377, + 378, + 379, + 380, + 381, + 382, + 383, + 384, + 385, + 386, + 387, + 388, + 389, + 390, + 391, + 392, + 393, + 394, + 395, + 396, + 397, + 398, + 399, + 400, + 401, + 402, + 403, + 404, + 405, + 406, + 407, + 408, + 409, + 410, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, + 423, + 424, + 425, + 426, + 427, + 428, + 429, + 430, + 431, + 432, + 433, + 434, + 435, + 436, + 437, + 438, + 439, + 440, + 441, + 442, + 443, + 444, + 445, + 446, + 447, + 448, + 449, + 450, + 451, + 452, + 453, + 454, + 455, + 456, + 457, + 458, + 459, + 460, + 461, + 462, + 463, + 464, + 465, + 466, + 467, + 468, + 469, + 470, + 471, + 472, + 473, + 474, + 475, + 476, + 477, + 478, + 479, + 480, + 481, + 482, + 483, + 484, + 485, + 486, + 487, + 488, + 489, + 490, + 491, + 492, + 493, + 494, + 495, + 496, + 497, + 498, + 499, + 500, + 501, + 502, + 503, + 504, + 505, + 506, + 507, + 508, + 509, +] + + +class ImgLinearBackbone(nn.Module): + def __init__( + self, + d_model: int, + patch_size: int, + in_chan: int = 3, + ) -> None: + super().__init__() + + self.conv_proj = nn.Conv2d( + in_chan, + out_channels=d_model, + kernel_size=patch_size, + stride=patch_size, + ) + self.d_model = d_model + + def forward(self, x: Tensor) -> Tensor: + x = self.conv_proj(x) + x = x.flatten(start_dim=-2).transpose(1, 2) + return x + + +class Encoder(nn.Module): + def __init__(self) -> None: + super().__init__() + + self.patch_size = 16 + self.d_model = 768 + self.dropout = 0 + self.activation = "gelu" + self.norm_first = True + self.ff_ratio = 4 + self.nhead = 12 + self.max_seq_len = 1024 + self.n_encoder_layer = 12 + encoder_layer = nn.TransformerEncoderLayer( + self.d_model, + nhead=self.nhead, + dim_feedforward=self.ff_ratio * self.d_model, + dropout=self.dropout, + activation=self.activation, + batch_first=True, + norm_first=self.norm_first, + ) + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm = norm_layer(self.d_model) + self.backbone = ImgLinearBackbone( + d_model=self.d_model, 
patch_size=self.patch_size + ) + self.pos_embed = PositionEmbedding( + max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout + ) + self.encoder = nn.TransformerEncoder( + encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False + ) + + def forward(self, x: Tensor) -> Tensor: + src_feature = self.backbone(x) + src_feature = self.pos_embed(src_feature) + memory = self.encoder(src_feature) + memory = self.norm(memory) + return memory + + +class PositionEmbedding(nn.Module): + def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None: + super().__init__() + self.embedding = nn.Embedding(max_seq_len, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + # assume x is batch first + if input_pos is None: + _pos = torch.arange(x.shape[1], device=x.device) + else: + _pos = input_pos + out = self.embedding(_pos) + return self.dropout(out + x) + + +class TokenEmbedding(nn.Module): + def __init__( + self, + vocab_size: int, + d_model: int, + padding_idx: int, + ) -> None: + super().__init__() + assert vocab_size > 0 + self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) + + def forward(self, x: Tensor) -> Tensor: + return self.embedding(x) + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + n_layer: int = 4 + n_head: int = 12 + dim: int = 768 + intermediate_size: int = None + head_dim: int = 64 + activation: str = "gelu" + norm_first: bool = True + + def __post_init__(self): + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + +class KVCache(nn.Module): + def __init__( + self, + max_batch_size, + max_seq_length, + n_heads, + head_dim, + dtype=torch.bfloat16, + device="cpu", + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer( + "k_cache", + torch.zeros(cache_shape, dtype=dtype, device=device), + persistent=False, + ) + self.register_buffer( + "v_cache", + torch.zeros(cache_shape, dtype=dtype, device=device), + persistent=False, + ) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + # assert input_pos.shape[0] == k_val.shape[2] + + bs = k_val.shape[0] + k_out = self.k_cache + v_out = self.v_cache + k_out[:bs, :, input_pos] = k_val + v_out[:bs, :, input_pos] = v_val + + return k_out[:bs], v_out[:bs] + + +class GPTFastDecoder(nn.Module): + def __init__(self) -> None: + super().__init__() + + self.vocab_size = 960 + self.padding_idx = 2 + self.prefix_token_id = 11 + self.eos_id = 1 + self.max_seq_len = 1024 + self.dropout = 0 + self.d_model = 768 + self.nhead = 12 + self.activation = "gelu" + self.norm_first = True + self.n_decoder_layer = 4 + config = ModelArgs( + n_layer=self.n_decoder_layer, + n_head=self.nhead, + dim=self.d_model, + intermediate_size=self.d_model * 4, + activation=self.activation, + norm_first=self.norm_first, + ) + self.config = config + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) + self.token_embed = TokenEmbedding( + vocab_size=self.vocab_size, + d_model=self.d_model, + padding_idx=self.padding_idx, + ) + self.pos_embed = PositionEmbedding( + max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout + ) + self.generator = 
nn.Linear(self.d_model, self.vocab_size) + self.token_white_list = TOKEN_WHITE_LIST + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length, dtype, device): + for b in self.layers: + b.multihead_attn.k_cache = None + b.multihead_attn.v_cache = None + + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + + for b in self.layers: + b.self_attn.kv_cache = KVCache( + max_batch_size, + max_seq_length, + self.config.n_head, + head_dim, + dtype, + device, + ) + b.multihead_attn.k_cache = None + b.multihead_attn.v_cache = None + + self.causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ).to(device) + + def forward(self, memory: Tensor, tgt: Tensor) -> Tensor: + input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int) + tgt = tgt[:, -1:] + tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos) + # tgt = self.decoder(tgt_feature, memory, input_pos) + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): + logits = tgt_feature + tgt_mask = self.causal_mask[None, None, input_pos] + for i, layer in enumerate(self.layers): + logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask) + # return output + logits = self.generator(logits)[:, -1, :] + total = set([i for i in range(logits.shape[-1])]) + black_list = list(total.difference(set(self.token_white_list))) + logits[..., black_list] = -1e9 + probs = F.softmax(logits, dim=-1) + _, next_tokens = probs.topk(1) + return next_tokens + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.self_attn = Attention(config) + self.multihead_attn = CrossAttention(config) + + layer_norm_eps = 1e-5 + + d_model = config.dim + dim_feedforward = config.intermediate_size + + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm_first = config.norm_first + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.activation = _get_activation_fn(config.activation) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Tensor, + input_pos: Tensor, + ) -> Tensor: + if self.norm_first: + x = tgt + x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos) + x = x + self.multihead_attn(self.norm2(x), memory) + x = x + self._ff_block(self.norm3(x)) + else: + x = tgt + x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos)) + x = self.norm2(x + self.multihead_attn(x, memory)) + x = self.norm3(x + self._ff_block(x)) + return x + + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.activation(self.linear1(x))) + return x + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, 3 * config.dim) + self.wo = nn.Linear(config.dim, config.dim) + + self.kv_cache: Optional[KVCache] = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + 
self.dim = config.dim + + def forward( + self, + x: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_head * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_head, self.head_dim) + v = v.view(bsz, seqlen, self.n_head, self.head_dim) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class CrossAttention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + self.query = nn.Linear(config.dim, config.dim) + self.key = nn.Linear(config.dim, config.dim) + self.value = nn.Linear(config.dim, config.dim) + self.out = nn.Linear(config.dim, config.dim) + + self.k_cache = None + self.v_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + + def get_kv(self, xa: torch.Tensor): + if self.k_cache is not None and self.v_cache is not None: + return self.k_cache, self.v_cache + + k = self.key(xa) + v = self.value(xa) + + # Reshape for correct format + batch_size, source_seq_len, _ = k.shape + k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim) + v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim) + + if self.k_cache is None: + self.k_cache = k + if self.v_cache is None: + self.v_cache = v + + return k, v + + def forward( + self, + x: Tensor, + xa: Tensor, + ): + q = self.query(x) + batch_size, target_seq_len, _ = q.shape + q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim) + k, v = self.get_kv(xa) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + wv = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + is_causal=False, + ) + wv = wv.transpose(1, 2).reshape( + batch_size, + target_seq_len, + self.n_head * self.head_dim, + ) + + return self.out(wv) diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/utils.py b/helper/content_recognition/rapid_table_pipeline/table_structure/utils.py new file mode 100644 index 0000000..484a63d --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/table_structure/utils.py @@ -0,0 +1,544 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import os +import platform +import traceback +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import cv2 +import numpy as np +from onnxruntime import ( + GraphOptimizationLevel, + InferenceSession, + SessionOptions, + get_available_providers, + get_device, +) + +from rapid_table.utils import Logger + + +class EP(Enum): + CPU_EP = "CPUExecutionProvider" + CUDA_EP = "CUDAExecutionProvider" + DIRECTML_EP = "DmlExecutionProvider" + + +class OrtInferSession: + def __init__(self, config: Dict[str, Any]): + self.logger = Logger(logger_name=__name__).get_log() + + model_path = config.get("model_path", None) + self._verify_model(model_path) + + self.cfg_use_cuda = config.get("use_cuda", None) + self.cfg_use_dml = config.get("use_dml", None) + + self.had_providers: List[str] = get_available_providers() + EP_list = self._get_ep_list() + + sess_opt = self._init_sess_opts(config) + self.session = InferenceSession( + model_path, + sess_options=sess_opt, + providers=EP_list, + ) + self._verify_providers() + + @staticmethod + def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions: + sess_opt = SessionOptions() + sess_opt.log_severity_level = 4 + sess_opt.enable_cpu_mem_arena = False + sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + + cpu_nums = os.cpu_count() + intra_op_num_threads = config.get("intra_op_num_threads", -1) + if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: + sess_opt.intra_op_num_threads = intra_op_num_threads + + inter_op_num_threads = config.get("inter_op_num_threads", -1) + if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums: + sess_opt.inter_op_num_threads = inter_op_num_threads + + return sess_opt + + def get_metadata(self, key: str = "character") -> list: + meta_dict = self.session.get_modelmeta().custom_metadata_map + content_list = meta_dict[key].splitlines() + return content_list + + def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: + cpu_provider_opts = { + "arena_extend_strategy": "kSameAsRequested", + } + EP_list = [(EP.CPU_EP.value, cpu_provider_opts)] + + cuda_provider_opts = { + "device_id": 0, + "arena_extend_strategy": "kNextPowerOfTwo", + "cudnn_conv_algo_search": "EXHAUSTIVE", + "do_copy_in_default_stream": True, + } + self.use_cuda = self._check_cuda() + if self.use_cuda: + EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts)) + + self.use_directml = self._check_dml() + if self.use_directml: + self.logger.info( + "Windows 10 or above detected, try to use DirectML as primary provider" + ) + directml_options = ( + cuda_provider_opts if self.use_cuda else cpu_provider_opts + ) + EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options)) + return EP_list + + def _check_cuda(self) -> bool: + if not self.cfg_use_cuda: + return False + + cur_device = get_device() + if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers: + return True + + self.logger.warning( + "%s is not in available providers (%s). Use %s inference by default.", + EP.CUDA_EP.value, + self.had_providers, + self.had_providers[0], + ) + self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.") + self.logger.info( + "(For reference only) If you want to use GPU acceleration, you must do:" + ) + self.logger.info( + "First, uninstall all onnxruntime pakcages in current environment." 
+ ) + self.logger.info( + "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`." + ) + self.logger.info( + "\tNote the onnxruntime-gpu version must match your cuda and cudnn version." + ) + self.logger.info( + "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html" + ) + self.logger.info( + "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']", + EP.CUDA_EP.value, + ) + return False + + def _check_dml(self) -> bool: + if not self.cfg_use_dml: + return False + + cur_os = platform.system() + if cur_os != "Windows": + self.logger.warning( + "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.", + cur_os, + self.had_providers[0], + ) + return False + + cur_window_version = int(platform.release().split(".")[0]) + if cur_window_version < 10: + self.logger.warning( + "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.", + cur_window_version, + self.had_providers[0], + ) + return False + + if EP.DIRECTML_EP.value in self.had_providers: + return True + + self.logger.warning( + "%s is not in available providers (%s). Use %s inference by default.", + EP.DIRECTML_EP.value, + self.had_providers, + self.had_providers[0], + ) + self.logger.info("If you want to use DirectML acceleration, you must do:") + self.logger.info( + "First, uninstall all onnxruntime pakcages in current environment." + ) + self.logger.info( + "Second, install onnxruntime-directml by `pip install onnxruntime-directml`" + ) + self.logger.info( + "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']", + EP.DIRECTML_EP.value, + ) + return False + + def _verify_providers(self): + session_providers = self.session.get_providers() + first_provider = session_providers[0] + + if self.use_cuda and first_provider != EP.CUDA_EP.value: + self.logger.warning( + "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.", + EP.CUDA_EP.value, + first_provider, + ) + + if self.use_directml and first_provider != EP.DIRECTML_EP.value: + self.logger.warning( + "%s is not available for current env, the inference part is automatically shifted to be executed under %s.", + EP.DIRECTML_EP.value, + first_provider, + ) + + def __call__(self, input_content: List[np.ndarray]) -> np.ndarray: + input_dict = dict(zip(self.get_input_names(), input_content)) + try: + return self.session.run(None, input_dict) + except Exception as e: + error_info = traceback.format_exc() + raise ONNXRuntimeError(error_info) from e + + def get_input_names(self) -> List[str]: + return [v.name for v in self.session.get_inputs()] + + def get_output_names(self) -> List[str]: + return [v.name for v in self.session.get_outputs()] + + def get_character_list(self, key: str = "character") -> List[str]: + meta_dict = self.session.get_modelmeta().custom_metadata_map + return meta_dict[key].splitlines() + + def have_key(self, key: str = "character") -> bool: + meta_dict = self.session.get_modelmeta().custom_metadata_map + if key in meta_dict.keys(): + return True + return False + + @staticmethod + def _verify_model(model_path: Union[str, Path, None]): + if model_path is None: + raise ValueError("model_path is None!") + + model_path = Path(model_path) + if not model_path.exists(): + raise FileNotFoundError(f"{model_path} does not exists.") + + if not model_path.is_file(): + raise 
FileExistsError(f"{model_path} is not a file.") + + +class ONNXRuntimeError(Exception): + pass + + +class TableLabelDecode: + def __init__(self, dict_character, merge_no_span_structure=True, **kwargs): + if merge_no_span_structure: + if "" not in dict_character: + dict_character.append("") + if "" in dict_character: + dict_character.remove("") + + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character + self.td_token = ["", ""] + + def __call__(self, preds, batch=None): + structure_probs = preds["structure_probs"] + bbox_preds = preds["loc_preds"] + shape_list = batch[-1] + result = self.decode(structure_probs, bbox_preds, shape_list) + if len(batch) == 1: # only contains shape + return result + + label_decode_result = self.decode_label(batch) + return result, label_decode_result + + def decode(self, structure_probs, bbox_preds, shape_list): + """convert text-label into text-index.""" + ignored_tokens = self.get_ignored_tokens() + end_idx = self.dict[self.end_str] + + structure_idx = structure_probs.argmax(axis=2) + structure_probs = structure_probs.max(axis=2) + + structure_batch_list = [] + bbox_batch_list = [] + batch_size = len(structure_idx) + for batch_idx in range(batch_size): + structure_list = [] + bbox_list = [] + score_list = [] + for idx in range(len(structure_idx[batch_idx])): + char_idx = int(structure_idx[batch_idx][idx]) + if idx > 0 and char_idx == end_idx: + break + + if char_idx in ignored_tokens: + continue + + text = self.character[char_idx] + if text in self.td_token: + bbox = bbox_preds[batch_idx, idx] + bbox = self._bbox_decode(bbox, shape_list[batch_idx]) + bbox_list.append(bbox) + structure_list.append(text) + score_list.append(structure_probs[batch_idx, idx]) + structure_batch_list.append([structure_list, np.mean(score_list)]) + bbox_batch_list.append(np.array(bbox_list)) + result = { + "bbox_batch_list": bbox_batch_list, + "structure_batch_list": structure_batch_list, + } + return result + + def decode_label(self, batch): + """convert text-label into text-index.""" + structure_idx = batch[1] + gt_bbox_list = batch[2] + shape_list = batch[-1] + ignored_tokens = self.get_ignored_tokens() + end_idx = self.dict[self.end_str] + + structure_batch_list = [] + bbox_batch_list = [] + batch_size = len(structure_idx) + for batch_idx in range(batch_size): + structure_list = [] + bbox_list = [] + for idx in range(len(structure_idx[batch_idx])): + char_idx = int(structure_idx[batch_idx][idx]) + if idx > 0 and char_idx == end_idx: + break + + if char_idx in ignored_tokens: + continue + + structure_list.append(self.character[char_idx]) + + bbox = gt_bbox_list[batch_idx][idx] + if bbox.sum() != 0: + bbox = self._bbox_decode(bbox, shape_list[batch_idx]) + bbox_list.append(bbox) + + structure_batch_list.append(structure_list) + bbox_batch_list.append(bbox_list) + result = { + "bbox_batch_list": bbox_batch_list, + "structure_batch_list": structure_batch_list, + } + return result + + def _bbox_decode(self, bbox, shape): + h, w = shape[:2] + bbox[0::2] *= w + bbox[1::2] *= h + return bbox + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + return np.array(self.dict[self.beg_str]) + + if beg_or_end == "end": + return np.array(self.dict[self.end_str]) + + raise TypeError(f"unsupport type 
{beg_or_end} in get_beg_end_flag_idx") + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = [self.beg_str] + dict_character + [self.end_str] + return dict_character + + +class TablePreprocess: + def __init__(self): + self.table_max_len = 488 + self.build_pre_process_list() + self.ops = self.create_operators() + + def __call__(self, data): + """transform""" + if self.ops is None: + self.ops = [] + + for op in self.ops: + data = op(data) + if data is None: + return None + return data + + def create_operators( + self, + ): + """ + create operators based on the config + + Args: + params(list): a dict list, used to create some operators + """ + assert isinstance( + self.pre_process_list, list + ), "operator config should be a list" + ops = [] + for operator in self.pre_process_list: + assert ( + isinstance(operator, dict) and len(operator) == 1 + ), "yaml format error" + op_name = list(operator)[0] + param = {} if operator[op_name] is None else operator[op_name] + op = eval(op_name)(**param) + ops.append(op) + return ops + + def build_pre_process_list(self): + resize_op = { + "ResizeTableImage": { + "max_len": self.table_max_len, + } + } + pad_op = { + "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]} + } + normalize_op = { + "NormalizeImage": { + "std": [0.229, 0.224, 0.225], + "mean": [0.485, 0.456, 0.406], + "scale": "1./255.", + "order": "hwc", + } + } + to_chw_op = {"ToCHWImage": None} + keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}} + self.pre_process_list = [ + resize_op, + normalize_op, + pad_op, + to_chw_op, + keep_keys_op, + ] + + +class ResizeTableImage: + def __init__(self, max_len, resize_bboxes=False, infer_mode=False): + super(ResizeTableImage, self).__init__() + self.max_len = max_len + self.resize_bboxes = resize_bboxes + self.infer_mode = infer_mode + + def __call__(self, data): + img = data["image"] + height, width = img.shape[0:2] + ratio = self.max_len / (max(height, width) * 1.0) + resize_h = int(height * ratio) + resize_w = int(width * ratio) + resize_img = cv2.resize(img, (resize_w, resize_h)) + if self.resize_bboxes and not self.infer_mode: + data["bboxes"] = data["bboxes"] * ratio + data["image"] = resize_img + data["src_img"] = img + data["shape"] = np.array([height, width, ratio, ratio]) + data["max_len"] = self.max_len + return data + + +class PaddingTableImage: + def __init__(self, size, **kwargs): + super(PaddingTableImage, self).__init__() + self.size = size + + def __call__(self, data): + img = data["image"] + pad_h, pad_w = self.size + padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32) + height, width = img.shape[0:2] + padding_img[0:height, 0:width, :] = img.copy() + data["image"] = padding_img + shape = data["shape"].tolist() + shape.extend([pad_h, pad_w]) + data["shape"] = np.array(shape) + return data + + +class NormalizeImage: + """normalize image such as substract mean, divide std""" + + def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == "chw" else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype("float32") + self.std = np.array(std).reshape(shape).astype("float32") + + def __call__(self, data): + img = np.array(data["image"]) + assert 
isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" + data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std + return data + + +class ToCHWImage: + """convert hwc image to chw image""" + + def __init__(self, **kwargs): + pass + + def __call__(self, data): + img = np.array(data["image"]) + data["image"] = img.transpose((2, 0, 1)) + return data + + +class KeepKeys: + def __init__(self, keep_keys, **kwargs): + self.keep_keys = keep_keys + + def __call__(self, data): + data_list = [] + for key in self.keep_keys: + data_list.append(data[key]) + return data_list + + +def trans_char_ocr_res(ocr_res): + word_result = [] + for res in ocr_res: + score = res[2] + for word_box, word in zip(res[3], res[4]): + word_res = [] + word_res.append(word_box) + word_res.append(word) + word_res.append(score) + word_result.append(word_res) + return word_result diff --git a/helper/content_recognition/rapid_table_pipeline/utils/__init__.py b/helper/content_recognition/rapid_table_pipeline/utils/__init__.py new file mode 100644 index 0000000..8754555 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/utils/__init__.py @@ -0,0 +1,8 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from .download_model import DownloadModel +from .load_image import LoadImage +from .logger import Logger +from .utils import is_url +from .vis import VisTable diff --git a/helper/content_recognition/rapid_table_pipeline/utils/download_model.py b/helper/content_recognition/rapid_table_pipeline/utils/download_model.py new file mode 100644 index 0000000..7d35a88 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/utils/download_model.py @@ -0,0 +1,67 @@ +import io +from pathlib import Path +from typing import Optional, Union + +import requests +from tqdm import tqdm + +from .logger import Logger + +PROJECT_DIR = Path(__file__).resolve().parent.parent +DEFAULT_MODEL_DIR = PROJECT_DIR / "models" + + +class DownloadModel: + logger = Logger(logger_name=__name__).get_log() + + @classmethod + def download( + cls, + model_full_url: Union[str, Path], + save_dir: Union[str, Path, None] = None, + save_model_name: Optional[str] = None, + ) -> str: + if save_dir is None: + save_dir = DEFAULT_MODEL_DIR + + save_dir.mkdir(parents=True, exist_ok=True) + + if save_model_name is None: + save_model_name = Path(model_full_url).name + + save_file_path = save_dir / save_model_name + if save_file_path.exists(): + cls.logger.info("%s already exists", save_file_path) + return str(save_file_path) + + try: + cls.logger.info("Download %s to %s", model_full_url, save_dir) + file = cls.download_as_bytes_with_progress(model_full_url, save_model_name) + cls.save_file(save_file_path, file) + except Exception as exc: + raise DownloadModelError from exc + return str(save_file_path) + + @staticmethod + def download_as_bytes_with_progress( + url: Union[str, Path], name: Optional[str] = None + ) -> bytes: + resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180) + total = int(resp.headers.get("content-length", 0)) + bio = io.BytesIO() + with tqdm( + desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024 + ) as pbar: + for chunk in resp.iter_content(chunk_size=65536): + pbar.update(len(chunk)) + bio.write(chunk) + return bio.getvalue() + + @staticmethod + def save_file(save_path: Union[str, Path], file: bytes): + with open(save_path, "wb") as f: + f.write(file) + + +class DownloadModelError(Exception): + pass diff --git 
a/helper/content_recognition/rapid_table_pipeline/utils/load_image.py b/helper/content_recognition/rapid_table_pipeline/utils/load_image.py new file mode 100644 index 0000000..86ab953 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/utils/load_image.py @@ -0,0 +1,131 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from io import BytesIO +from pathlib import Path +from typing import Any, Union + +import cv2 +import numpy as np +import requests +from PIL import Image, UnidentifiedImageError + +from .utils import is_url + +root_dir = Path(__file__).resolve().parent +InputType = Union[str, np.ndarray, bytes, Path, Image.Image] + + +class LoadImage: + def __init__(self): + pass + + def __call__(self, img: InputType) -> np.ndarray: + if not isinstance(img, InputType.__args__): + raise LoadImageError( + f"The img type {type(img)} does not in {InputType.__args__}" + ) + + origin_img_type = type(img) + img = self.load_img(img) + img = self.convert_img(img, origin_img_type) + return img + + def load_img(self, img: InputType) -> np.ndarray: + if isinstance(img, (str, Path)): + if is_url(img): + img = Image.open(requests.get(img, stream=True, timeout=60).raw) + else: + self.verify_exist(img) + img = Image.open(img) + + try: + img = self.img_to_ndarray(img) + except UnidentifiedImageError as e: + raise LoadImageError(f"cannot identify image file {img}") from e + return img + + if isinstance(img, bytes): + img = self.img_to_ndarray(Image.open(BytesIO(img))) + return img + + if isinstance(img, np.ndarray): + return img + + if isinstance(img, Image.Image): + return self.img_to_ndarray(img) + + raise LoadImageError(f"{type(img)} is not supported!") + + def img_to_ndarray(self, img: Image.Image) -> np.ndarray: + if img.mode == "1": + img = img.convert("L") + return np.array(img) + return np.array(img) + + def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray: + if img.ndim == 2: + return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if img.ndim == 3: + channel = img.shape[2] + if channel == 1: + return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if channel == 2: + return self.cvt_two_to_three(img) + + if channel == 3: + if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): + return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + if channel == 4: + return self.cvt_four_to_three(img) + + raise LoadImageError( + f"The channel({channel}) of the img is not in [1, 2, 3, 4]" + ) + + raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") + + @staticmethod + def cvt_two_to_three(img: np.ndarray) -> np.ndarray: + """gray + alpha → BGR""" + img_gray = img[..., 0] + img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) + + img_alpha = img[..., 1] + not_a = cv2.bitwise_not(img_alpha) + not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) + + new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) + new_img = cv2.add(new_img, not_a) + return new_img + + @staticmethod + def cvt_four_to_three(img: np.ndarray) -> np.ndarray: + """RGBA → BGR""" + r, g, b, a = cv2.split(img) + new_img = cv2.merge((b, g, r)) + + not_a = cv2.bitwise_not(a) + not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) + + new_img = cv2.bitwise_and(new_img, new_img, mask=a) + + mean_color = np.mean(new_img) + if mean_color <= 0.0: + new_img = cv2.add(new_img, not_a) + else: + new_img = cv2.bitwise_not(new_img) + return new_img + + @staticmethod + def verify_exist(file_path: Union[str, Path]): + if not Path(file_path).exists(): + raise 
LoadImageError(f"{file_path} does not exist.") + + +class LoadImageError(Exception): + pass diff --git a/helper/content_recognition/rapid_table_pipeline/utils/logger.py b/helper/content_recognition/rapid_table_pipeline/utils/logger.py new file mode 100644 index 0000000..d3be7fe --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/utils/logger.py @@ -0,0 +1,37 @@ +# -*- encoding: utf-8 -*- +# @Author: Jocker1212 +# @Contact: xinyijianggo@gmail.com +import logging + +import colorlog + + +class Logger: + def __init__(self, log_level=logging.DEBUG, logger_name=None): + self.logger = logging.getLogger(logger_name) + self.logger.setLevel(log_level) + self.logger.propagate = False + + formatter = colorlog.ColoredFormatter( + "%(log_color)s[%(levelname)s] %(asctime)s [RapidTable] %(filename)s:%(lineno)d: %(message)s", + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red,bg_white", + }, + ) + + if not self.logger.handlers: + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + + for handler in self.logger.handlers: + self.logger.removeHandler(handler) + + console_handler.setLevel(log_level) + self.logger.addHandler(console_handler) + + def get_log(self): + return self.logger diff --git a/helper/content_recognition/rapid_table_pipeline/utils/utils.py b/helper/content_recognition/rapid_table_pipeline/utils/utils.py new file mode 100644 index 0000000..a182929 --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/utils/utils.py @@ -0,0 +1,12 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from urllib.parse import urlparse + + +def is_url(url: str) -> bool: + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except Exception: + return False diff --git a/helper/content_recognition/rapid_table_pipeline/utils/vis.py b/helper/content_recognition/rapid_table_pipeline/utils/vis.py new file mode 100644 index 0000000..88fc69b --- /dev/null +++ b/helper/content_recognition/rapid_table_pipeline/utils/vis.py @@ -0,0 +1,145 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import os +from pathlib import Path +from typing import Optional, Union + +import cv2 +import numpy as np + +from .load_image import LoadImage + + +class VisTable: + def __init__(self): + self.load_img = LoadImage() + + def __call__( + self, + img_path: Union[str, Path], + table_results, + save_html_path: Optional[str] = None, + save_drawed_path: Optional[str] = None, + save_logic_path: Optional[str] = None, + ): + if save_html_path: + html_with_border = self.insert_border_style(table_results.pred_html) + self.save_html(save_html_path, html_with_border) + + table_cell_bboxes = table_results.cell_bboxes + if table_cell_bboxes is None: + return None + + img = self.load_img(img_path) + + dims_bboxes = table_cell_bboxes.shape[1] + if dims_bboxes == 4: + drawed_img = self.draw_rectangle(img, table_cell_bboxes) + elif dims_bboxes == 8: + drawed_img = self.draw_polylines(img, table_cell_bboxes) + else: + raise ValueError("Shape of table bounding boxes is not between in 4 or 8.") + + if save_drawed_path: + self.save_img(save_drawed_path, drawed_img) + + if save_logic_path and table_results.logic_points: + polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes] + self.plot_rec_box_with_logic_info( + img, save_logic_path, table_results.logic_points, polygons + ) + return drawed_img + + def insert_border_style(self, 
table_html_str: str): + style_res = """""" + + prefix_table, suffix_table = table_html_str.split("") + html_with_border = f"{prefix_table}{style_res}{suffix_table}" + return html_with_border + + def plot_rec_box_with_logic_info( + self, img: np.ndarray, output_path, logic_points, sorted_polygons + ): + """ + :param img_path + :param output_path + :param logic_points: [row_start,row_end,col_start,col_end] + :param sorted_polygons: [xmin,ymin,xmax,ymax] + :return: + """ + # 读取原图 + img = cv2.copyMakeBorder( + img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255] + ) + # 绘制 polygons 矩形 + for idx, polygon in enumerate(sorted_polygons): + x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3] + x0 = round(x0) + y0 = round(y0) + x1 = round(x1) + y1 = round(y1) + cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1) + # 增大字体大小和线宽 + font_scale = 0.9 # 原先是0.5 + thickness = 1 # 原先是1 + logic_point = logic_points[idx] + cv2.putText( + img, + f"row: {logic_point[0]}-{logic_point[1]}", + (x0 + 3, y0 + 8), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + cv2.putText( + img, + f"col: {logic_point[2]}-{logic_point[3]}", + (x0 + 3, y0 + 18), + cv2.FONT_HERSHEY_PLAIN, + font_scale, + (0, 0, 255), + thickness, + ) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + # 保存绘制后的图像 + self.save_img(output_path, img) + + @staticmethod + def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray: + img_copy = img.copy() + for box in boxes.astype(int): + x1, y1, x2, y2 = box + cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2) + return img_copy + + @staticmethod + def draw_polylines(img: np.ndarray, points) -> np.ndarray: + img_copy = img.copy() + for point in points.astype(int): + point = point.reshape(4, 2) + cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2) + return img_copy + + @staticmethod + def save_img(save_path: Union[str, Path], img: np.ndarray): + cv2.imwrite(str(save_path), img) + + @staticmethod + def save_html(save_path: Union[str, Path], html: str): + with open(save_path, "w", encoding="utf-8") as f: + f.write(html) diff --git a/helper/content_recognition/test.py b/helper/content_recognition/test.py new file mode 100644 index 0000000..1474bb8 --- /dev/null +++ b/helper/content_recognition/test.py @@ -0,0 +1,18 @@ +from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze +from magic_pdf.data.read_api import read_local_images +from markdownify import markdownify as md +import re + + +# proc +## Create Dataset Instance +input_file = "/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/207.jpg" + +ds = read_local_images(input_file)[0] + +x = ds.apply(doc_analyze, ocr=True) +x = x.pipe_ocr_mode(None) +html = x.get_markdown(None) +content = md(html) +content = re.sub(r'\\([#*_`])', r'\1', content) +print(content) diff --git a/helper/content_recognition/utils.py b/helper/content_recognition/utils.py new file mode 100644 index 0000000..7e9cb03 --- /dev/null +++ b/helper/content_recognition/utils.py @@ -0,0 +1,213 @@ +import os +import tempfile +import cv2 +import numpy as np +from paddleocr import PaddleOCR +from marker.converters.table import TableConverter +from marker.models import create_model_dict +from marker.output import text_from_rendered +from .rapid_table_pipeline.main import table2md_pipeline +from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze +from magic_pdf.data.read_api import read_local_images +from markdownify import markdownify as md +import re + + +def 
scanning_document_classify(image): + # 判断是否是扫描件 + + # 将图像从BGR颜色空间转换到HSV颜色空间 + hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + + # 定义红色的HSV范围 + lower_red1 = np.array([0, 70, 50]) + upper_red1 = np.array([10, 255, 255]) + lower_red2 = np.array([170, 70, 50]) + upper_red2 = np.array([180, 255, 255]) + + # 创建两个掩码,一个用于低色调的红色,一个用于高色调的红色 + mask1 = cv2.inRange(hsv, lower_red1, upper_red1) + mask2 = cv2.inRange(hsv, lower_red2, upper_red2) + + # 将两个掩码合并 + mask = cv2.bitwise_or(mask1, mask2) + + # 计算红色区域的非零像素数量 + non_zero_pixels = cv2.countNonZero(mask) + return 1 < non_zero_pixels < 1000 + + +def remove_watermark(image): + # 去除红色印章 + _, _, r_channel = cv2.split(image) + r_channel[r_channel > 210] = 255 + r_channel = cv2.cvtColor(r_channel, cv2.COLOR_GRAY2BGR) + return r_channel + + +def html2md(html_content): + md_content = md(html_content) + md_content = re.sub(r'\\([#*_`])', r'\1', md_content) + return md_content + + +def markdown_rec(image): + # TODO 可以传入文件夹 + image_path = f'{tempfile.mktemp()}.jpg' + cv2.imwrite(image_path, image) + + try: + ds = read_local_images(image_path)[0] + x = ds.apply(doc_analyze, ocr=True) + x = x.pipe_ocr_mode(None) + html = x.get_markdown(None) + finally: + os.remove(image_path) + return html2md(html) + + +ocr = PaddleOCR(lang='ch') # need to run only once to download and load model into memory + + +def text_rec(image): + result = ocr.ocr(image, cls=True) + output = [] + for idx in range(len(result)): + res = result[idx] + if not res: + continue + for line in res: + if not line: + continue + output.append(line[1][0]) + return output + + +def table_rec(image): + return table2md_pipeline(image) + + +table_converter = TableConverter(artifact_dict=create_model_dict()) + + +def scanning_document_rec(image): + # TODO 内部的ocr可以替换为paddleocr以提升文字识别精度 + image_path = f'{tempfile.mktemp()}.jpg' + cv2.imwrite(image_path, image) + + try: + no_watermark_image = remove_watermark(cv2.imread(image_path)) + prefix, suffix = image_path.split('.') + new_image_path = f'{prefix}_remove_watermark.{suffix}' + cv2.imwrite(new_image_path, no_watermark_image) + + rendered = table_converter(new_image_path) + text, _, _ = text_from_rendered(rendered) + finally: + os.remove(image_path) + return text, no_watermark_image + + +def compute_box_distance(box1, box2): + x11, y11, x12, y12 = box1 + x21, y21, x22, y22 = box2 + + # 计算水平和垂直方向的重叠量 + x_overlap = max(0, min(x12, x22) - max(x11, x21)) + y_overlap = max(0, min(y12, y22) - max(y11, y21)) + + # 如果有重叠(x和y都重叠),返回负的重叠深度(取 min 表示最小穿透) + if x_overlap > 0 and y_overlap > 0: + return -min(x_overlap, y_overlap) + + distances = [] + + # 如果 x 方向有投影重叠,计算上下边的距离 + if x12 > x21 and x11 < x22: + dist_top = y21 - y12 # box1下边到box2上边 + dist_bottom = y11 - y22 # box1上边到box2下边 + if dist_top > 0: + distances.append(dist_top) + if dist_bottom > 0: + distances.append(dist_bottom) + + # 如果 y 方向有投影重叠,计算左右边的距离 + if y12 > y21 and y11 < y22: + dist_left = x11 - x22 # box1左边到box2右边 + dist_right = x21 - x12 # box1右边到box2左边 + if dist_left > 0: + distances.append(dist_left) + if dist_right > 0: + distances.append(dist_right) + + # 如果有合法的距离,返回最小值,否则说明边无法对齐,返回 None + return min(distances) if distances else None + + +def assign_tables_to_titles(layout_results, max_distance=200): + tables = [_ for _ in layout_results if _.clsid == 4] + titles = [_ for _ in layout_results if _.clsid == 5] + + table_to_title = {} + title_to_table = {} + + changed = True + while changed: + changed = False + for title in titles: + title_id = id(title) + + best_table = None + min_dist = float('inf') + + 
for table in tables: + table_id = id(table) + + dist = compute_box_distance(title.box, table.box) + if dist is None or dist > max_distance: + continue + + if dist < min_dist: + min_dist = dist + best_table = table + + if best_table is None: + continue + + table_id = id(best_table) + + current_table = title_to_table.get(title_id) + if current_table is best_table: + continue # 已是最优,无需更新 + + prev_title = table_to_title.get(table_id) + if prev_title: + prev_title_id = id(prev_title) + prev_dist = compute_box_distance(prev_title.box, best_table.box) + if prev_dist is not None and prev_dist <= min_dist: + continue # 原标题绑定得更近,跳过 + + # 解绑旧标题 + title_to_table.pop(prev_title_id, None) + + # 更新新绑定 + title_to_table[title_id] = best_table + table_to_title[table_id] = title + changed = True # 有更新 + + # 最终写回绑定结果 + for table in tables: + table_id = id(table) + title = table_to_title.get(table_id) + if title: + table.table_title = title.content + else: + table.table_title = None + + +if __name__ == '__main__': + # content = text_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/5.jpg') + # content = markdown_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/3.jpg') + # content = table_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/6.jpg') + content = scanning_document_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/103.jpg') + print(content) diff --git a/helper/db_helper.py b/helper/db_helper.py new file mode 100644 index 0000000..fe0242a --- /dev/null +++ b/helper/db_helper.py @@ -0,0 +1,128 @@ +import psycopg2 +from psycopg2 import OperationalError, extras +from loguru import logger +import os +import traceback + + +# 创建连接(你只需要创建一次,然后在服务中复用) +def create_connection(): + try: + conn = psycopg2.connect( + dbname=os.environ['POSTGRESQL_DATABASE'], + user=os.environ['POSTGRESQL_USERNAME'], + password=os.environ['POSTGRESQL_PASSWORD'], + host=os.environ['POSTGRESQL_HOST'], + port=os.environ['POSTGRESQL_PORT'] + ) + conn.autocommit = False + with conn.cursor() as cur: + cur.execute("SET TIME ZONE 'Asia/Shanghai';") + conn.commit() + return conn + except OperationalError as e: + logger.error(f"连接数据库失败: {e}") + return None + + +# 插入数据的函数 +def insert_data(conn, table, data_dict, pk_name='id'): + """ + 向指定表中插入数据 + :param conn: psycopg2 connection 对象 + :param table: 表名(字符串) + :param data_dict: 字典格式的数据,比如 {"column1": value1, "column2": value2} + """ + if conn is None: + logger.error("数据库连接无效") + return + + try: + with conn.cursor() as cur: + columns = data_dict.keys() + values = [data_dict[col] for col in columns] + + # 构造 SQL + insert_query = f""" + INSERT INTO {table} ({', '.join(columns)}) + VALUES ({', '.join(['%s'] * len(values))}) + RETURNING {pk_name} + """ + cur.execute(insert_query, values) + inserted_id = cur.fetchone()[0] + return inserted_id + + except Exception as e: + logger.error(f"插入数据失败: {e}") + raise e + + +def insert_multiple_data(conn, table, data_list, batch_size=100): + """ + 批量向指定表中插入数据(使用 execute_values) + :param conn: psycopg2 connection 对象 + :param table: 表名(字符串) + :param data_list: 包含多个字典的列表,每个字典代表一行数据,比如 [{"column1": value1, "column2": value2}, ...] 
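Returning to compute_box_distance and assign_tables_to_titles in helper/content_recognition/utils.py above: captions (clsid 5) are bound to the nearest table (clsid 4) whose box lies within max_distance pixels along an overlapping projection. A small self-contained check, using Box as a stand-in for LayoutRecognitionResult (only the attributes the functions touch are provided) and assuming both functions are in scope:

```python
# Self-contained check of the table/caption pairing logic.
class Box:
    def __init__(self, clsid, content, box):
        self.clsid, self.content, self.box = clsid, content, box
        self.table_title = None


caption = Box(5, 'Table 1: test results', [100, 90, 400, 110])   # table caption
table = Box(4, '<table>...</table>', [100, 120, 400, 300])       # table body

# x-projections overlap, vertical gap is 120 - 110 = 10 pixels:
print(compute_box_distance(caption.box, table.box))   # -> 10

assign_tables_to_titles([caption, table], max_distance=200)
print(table.table_title)   # -> 'Table 1: test results'
```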
+ """ + if conn is None: + logger.error("数据库连接无效") + return + + try: + with conn.cursor() as cur: + columns = data_list[0].keys() + insert_query = f""" + INSERT INTO {table} ({', '.join(columns)}) + VALUES %s + """ + for i in range(0, len(data_list), batch_size): + batch = data_list[i:i + batch_size] + values = [tuple(d.values()) for d in batch] + extras.execute_values(cur, insert_query, values) + + except Exception as e: + logger.error(f"批量插入数据失败: {e}") + raise e + + +conn = create_connection() + + +def insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, rec_results): + data_dict = { + 'path': pdf_path, + 'filename': pdf_name, + 'process_status': process_status, + 'analysis_start_time': start_time, + 'analysis_end_time': end_time + } + try: + inserted_id = insert_data(conn, 'pdf_info', data_dict) + if process_status == 2: + data_list = [] + for i in range(len(rec_results)): + # 每一页 + page_no = i + 1 + for j in range(len(rec_results[i])): + # 每一个box + box = rec_results[i][j] + content = box.content + clsid = box.clsid + table_title = box.table_title + order = j + data_dict = { + 'layout_type': clsid, + 'content': content, + 'page_no': page_no, + 'pdf_id': inserted_id, + 'table_title': table_title, + 'display_order': order + } + data_list.append(data_dict) + insert_multiple_data(conn, 'pdf_analysis_output', data_list) + conn.commit() + return inserted_id + except Exception as e: + conn.rollback() + logger.error(f'operate database error!\n{traceback.format_exc()}') + raise e diff --git a/helper/image_helper.py b/helper/image_helper.py new file mode 100644 index 0000000..ec4c429 --- /dev/null +++ b/helper/image_helper.py @@ -0,0 +1,47 @@ +from typing import List +from pdf2image import convert_from_path +import os +import paddleclas +import cv2 +from .page_detection.utils import PageDetectionResult + + +paddle_clas_model = paddleclas.PaddleClas(model_name="text_image_orientation") + +def pdf2image(pdf_path, output_dir): + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + images = convert_from_path(pdf_path) + for i, image in enumerate(images): + image.save(f'{output_dir}/{i + 1}.jpg') + + +def image_orient_cls(input_data): + return paddle_clas_model.predict(input_data) + + +def page_detection_visual(page_detection_result: PageDetectionResult): + img = cv2.imread(page_detection_result.image_path) + for box in page_detection_result.boxes: + pos = box.pos + clsid = box.clsid + confidence = box.confidence + if clsid == 0: + color = (0, 0, 0) + text = 'text' + elif clsid == 1: + color = (255, 0, 0) + text = 'title' + elif clsid == 2: + color = (0, 255, 0) + text = 'figure' + elif clsid == 4: + color = (0, 0, 255) + text = 'table' + if clsid == 5: + color = (255, 0, 255) + text = 'table caption' + text = f'{text} {confidence}' + img = cv2.rectangle(img, (int(pos[0]), int(pos[1])), (int(pos[2]), int(pos[3])), color, 2) + cv2.putText(img, text, (int(pos[0]), int(pos[1])), cv2.FONT_HERSHEY_TRIPLEX, 1, color, 2) + return img diff --git a/helper/page_detection/clrnet_postprocess.py b/helper/page_detection/clrnet_postprocess.py new file mode 100644 index 0000000..efaa345 --- /dev/null +++ b/helper/page_detection/clrnet_postprocess.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.nn as nn +from scipy.special import softmax +from scipy.interpolate import InterpolatedUnivariateSpline + + +def line_iou(pred, target, img_w, length=15, aligned=True): + ''' + Calculate the line iou value between predictions and targets + Args: + pred: lane predictions, shape: (num_pred, 72) + target: ground truth, shape: (num_target, 72) + img_w: image width + length: extended radius + aligned: True for iou loss calculation, False for pair-wise ious in assign + ''' + px1 = pred - length + px2 = pred + length + tx1 = target - length + tx2 = target + length + + if aligned: + invalid_mask = target + ovr = paddle.minimum(px2, tx2) - paddle.maximum(px1, tx1) + union = paddle.maximum(px2, tx2) - paddle.minimum(px1, tx1) + else: + num_pred = pred.shape[0] + invalid_mask = target.tile([num_pred, 1, 1]) + + ovr = (paddle.minimum(px2[:, None, :], tx2[None, ...]) - paddle.maximum( + px1[:, None, :], tx1[None, ...])) + union = (paddle.maximum(px2[:, None, :], tx2[None, ...]) - + paddle.minimum(px1[:, None, :], tx1[None, ...])) + + invalid_masks = (invalid_mask < 0) | (invalid_mask >= img_w) + + ovr[invalid_masks] = 0. + union[invalid_masks] = 0. + iou = ovr.sum(axis=-1) / (union.sum(axis=-1) + 1e-9) + return iou + + +class Lane: + def __init__(self, points=None, invalid_value=-2., metadata=None): + super(Lane, self).__init__() + self.curr_iter = 0 + self.points = points + self.invalid_value = invalid_value + self.function = InterpolatedUnivariateSpline( + points[:, 1], points[:, 0], k=min(3, len(points) - 1)) + self.min_y = points[:, 1].min() - 0.01 + self.max_y = points[:, 1].max() + 0.01 + self.metadata = metadata or {} + + def __repr__(self): + return '[Lane]\n' + str(self.points) + '\n[/Lane]' + + def __call__(self, lane_ys): + lane_xs = self.function(lane_ys) + + lane_xs[(lane_ys < self.min_y) | (lane_ys > self.max_y + )] = self.invalid_value + return lane_xs + + def to_array(self, sample_y_range, img_w, img_h): + self.sample_y = range(sample_y_range[0], sample_y_range[1], + sample_y_range[2]) + sample_y = self.sample_y + img_w, img_h = img_w, img_h + ys = np.array(sample_y) / float(img_h) + xs = self(ys) + valid_mask = (xs >= 0) & (xs < 1) + lane_xs = xs[valid_mask] * img_w + lane_ys = ys[valid_mask] * img_h + lane = np.concatenate( + (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), axis=1) + return lane + + def __iter__(self): + return self + + def __next__(self): + if self.curr_iter < len(self.points): + self.curr_iter += 1 + return self.points[self.curr_iter - 1] + self.curr_iter = 0 + raise StopIteration + + +class CLRNetPostProcess(object): + """ + Args: + input_shape (int): network input image size + ori_shape (int): ori image shape of before padding + scale_factor (float): scale factor of ori image + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, img_w, ori_img_h, cut_height, conf_threshold, nms_thres, + max_lanes, num_points): + self.img_w = img_w + self.conf_threshold = conf_threshold + self.nms_thres = nms_thres + self.max_lanes = max_lanes + self.num_points = num_points + self.n_strips = 
num_points - 1 + self.n_offsets = num_points + self.ori_img_h = ori_img_h + self.cut_height = cut_height + + self.prior_ys = paddle.linspace( + start=1, stop=0, num=self.n_offsets).astype('float64') + + def predictions_to_pred(self, predictions): + """ + Convert predictions to internal Lane structure for evaluation. + """ + lanes = [] + for lane in predictions: + lane_xs = lane[6:].clone() + start = min( + max(0, int(round(lane[2].item() * self.n_strips))), + self.n_strips) + length = int(round(lane[5].item())) + end = start + length - 1 + end = min(end, len(self.prior_ys) - 1) + if start > 0: + mask = ((lane_xs[:start] >= 0.) & + (lane_xs[:start] <= 1.)).cpu().detach().numpy()[::-1] + mask = ~((mask.cumprod()[::-1]).astype(np.bool_)) + lane_xs[:start][mask] = -2 + if end < len(self.prior_ys) - 1: + lane_xs[end + 1:] = -2 + + lane_ys = self.prior_ys[lane_xs >= 0].clone() + lane_xs = lane_xs[lane_xs >= 0] + lane_xs = lane_xs.flip(axis=0).astype('float64') + lane_ys = lane_ys.flip(axis=0) + + lane_ys = (lane_ys * + (self.ori_img_h - self.cut_height) + self.cut_height + ) / self.ori_img_h + if len(lane_xs) <= 1: + continue + points = paddle.stack( + x=(lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])), + axis=1).squeeze(axis=2) + lane = Lane( + points=points.cpu().numpy(), + metadata={ + 'start_x': lane[3], + 'start_y': lane[2], + 'conf': lane[1] + }) + lanes.append(lane) + return lanes + + def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k): + """ + NMS for lane detection. + predictions: paddle.Tensor [num_lanes,conf,y,x,lenght,72offsets] [12,77] + scores: paddle.Tensor [num_lanes] + nms_overlap_thresh: float + top_k: int + """ + # sort by scores to get idx + idx = scores.argsort(descending=True) + keep = [] + + condidates = predictions.clone() + condidates = condidates.index_select(idx) + + while len(condidates) > 0: + keep.append(idx[0]) + if len(keep) >= top_k or len(condidates) == 1: + break + + ious = [] + for i in range(1, len(condidates)): + ious.append(1 - line_iou( + condidates[i].unsqueeze(0), + condidates[0].unsqueeze(0), + img_w=self.img_w, + length=15)) + ious = paddle.to_tensor(ious) + + mask = ious <= nms_overlap_thresh + id = paddle.where(mask == False)[0] + + if id.shape[0] == 0: + break + condidates = condidates[1:].index_select(id) + idx = idx[1:].index_select(id) + keep = paddle.stack(keep) + + return keep + + def get_lanes(self, output, as_lanes=True): + """ + Convert model output to lanes. 
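line_iou above widens every per-row x-position by length pixels and intersects the resulting segments, so two lanes still score high when their x-coordinates differ by only a few pixels. A tiny numeric check of the aligned case with made-up lane columns:

```python
# Made-up lane columns: three sampled rows, x-positions in pixels.
import paddle

pred = paddle.to_tensor([[100., 102., 104.]])
target = paddle.to_tensor([[100., 110., 104.]])

# Per row the segments are [x-15, x+15]; overlaps are 30, 22, 30 and unions
# 30, 38, 30, so the aligned IoU is 82 / 98, roughly 0.84.
print(line_iou(pred, target, img_w=800, length=15, aligned=True))
```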
+ """ + softmax = nn.Softmax(axis=1) + decoded = [] + + for predictions in output: + if len(predictions) == 0: + decoded.append([]) + continue + threshold = self.conf_threshold + scores = softmax(predictions[:, :2])[:, 1] + keep_inds = scores >= threshold + predictions = predictions[keep_inds] + scores = scores[keep_inds] + + if predictions.shape[0] == 0: + decoded.append([]) + continue + nms_predictions = predictions.detach().clone() + nms_predictions = paddle.concat( + x=[nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1) + + nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips + nms_predictions[..., 5:] = nms_predictions[..., 5:] * ( + self.img_w - 1) + + keep = self.lane_nms( + nms_predictions[..., 5:], + scores, + nms_overlap_thresh=self.nms_thres, + top_k=self.max_lanes) + + predictions = predictions.index_select(keep) + + if predictions.shape[0] == 0: + decoded.append([]) + continue + predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips) + if as_lanes: + pred = self.predictions_to_pred(predictions) + else: + pred = predictions + decoded.append(pred) + return decoded + + def __call__(self, lanes_list): + lanes = self.get_lanes(lanes_list) + return lanes diff --git a/helper/page_detection/keypoint_preprocess.py b/helper/page_detection/keypoint_preprocess.py new file mode 100644 index 0000000..b4e50e8 --- /dev/null +++ b/helper/page_detection/keypoint_preprocess.py @@ -0,0 +1,243 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +this code is based on https://github.com/open-mmlab/mmpose/mmpose/core/post_processing/post_transforms.py +""" +import cv2 +import numpy as np + + +class EvalAffine(object): + def __init__(self, size, stride=64): + super(EvalAffine, self).__init__() + self.size = size + self.stride = stride + + def __call__(self, image, im_info): + s = self.size + h, w, _ = image.shape + trans, size_resized = get_affine_mat_kernel(h, w, s, inv=False) + image_resized = cv2.warpAffine(image, trans, size_resized) + return image_resized, im_info + + +def get_affine_mat_kernel(h, w, s, inv=False): + if w < h: + w_ = s + h_ = int(np.ceil((s / w * h) / 64.) * 64) + scale_w = w + scale_h = h_ / w_ * w + + else: + h_ = s + w_ = int(np.ceil((s / h * w) / 64.) * 64) + scale_h = h + scale_w = w_ / h_ * h + + center = np.array([np.round(w / 2.), np.round(h / 2.)]) + + size_resized = (w_, h_) + trans = get_affine_transform( + center, np.array([scale_w, scale_h]), 0, size_resized, inv=inv) + + return trans, size_resized + + +def get_affine_transform(center, + input_size, + rot, + output_size, + shift=(0., 0.), + inv=False): + """Get the affine transform matrix, given the center/scale/rot/output_size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ]): Size of the destination heatmaps. 
+ shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: The transform matrix. + """ + assert len(center) == 2 + assert len(output_size) == 2 + assert len(shift) == 2 + if not isinstance(input_size, (np.ndarray, list)): + input_size = np.array([input_size, input_size], dtype=np.float32) + scale_tmp = input_size + + shift = np.array(shift) + src_w = scale_tmp[0] + dst_w = output_size[0] + dst_h = output_size[1] + + rot_rad = np.pi * rot / 180 + src_dir = rotate_point([0., src_w * -0.5], rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale_tmp * shift + src[1, :] = center + src_dir + scale_tmp * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + trans = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + trans = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return trans + + +def get_warp_matrix(theta, size_input, size_dst, size_target): + """This code is based on + https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py + + Calculate the transformation matrix under the constraint of unbiased. + Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased + Data Processing for Human Pose Estimation (CVPR 2020). + + Args: + theta (float): Rotation angle in degrees. + size_input (np.ndarray): Size of input image [w, h]. + size_dst (np.ndarray): Size of output image [w, h]. + size_target (np.ndarray): Size of ROI in input plane [w, h]. + + Returns: + matrix (np.ndarray): A matrix for transformation. + """ + theta = np.deg2rad(theta) + matrix = np.zeros((2, 3), dtype=np.float32) + scale_x = size_dst[0] / size_target[0] + scale_y = size_dst[1] / size_target[1] + matrix[0, 0] = np.cos(theta) * scale_x + matrix[0, 1] = -np.sin(theta) * scale_x + matrix[0, 2] = scale_x * ( + -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] * + np.sin(theta) + 0.5 * size_target[0]) + matrix[1, 0] = np.sin(theta) * scale_y + matrix[1, 1] = np.cos(theta) * scale_y + matrix[1, 2] = scale_y * ( + -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] * + np.cos(theta) + 0.5 * size_target[1]) + return matrix + + +def rotate_point(pt, angle_rad): + """Rotate a point by an angle. + + Args: + pt (list[float]): 2 dimensional point to be rotated + angle_rad (float): rotation angle by radian + + Returns: + list[float]: Rotated point. + """ + assert len(pt) == 2 + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + new_x = pt[0] * cs - pt[1] * sn + new_y = pt[0] * sn + pt[1] * cs + rotated_pt = [new_x, new_y] + + return rotated_pt + + +def _get_3rd_point(a, b): + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): point(x,y) + b (np.ndarray): point(x,y) + + Returns: + np.ndarray: The 3rd point. 
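get_affine_transform above produces the 2x3 matrix that cv2.warpAffine uses to crop a centred (optionally rotated) box out of the source image and resample it to the network input size; rotate_point is its small helper. A quick hedged check with made-up box and output sizes:

```python
# Made-up person box (200x400) centred at (320, 240), mapped to a 192x256 input.
import numpy as np

center = np.array([320., 240.])
box_size = np.array([200., 400.])   # (w, h) of the region to crop
trans = get_affine_transform(center, box_size, 0, (192, 256))
print(trans.shape)   # (2, 3) matrix, ready for cv2.warpAffine

# rotate_point sanity check: (1, 0) rotated by 90 degrees becomes (0, 1).
print(np.round(rotate_point([1., 0.], np.pi / 2), 6))   # [0. 1.]
```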
+ """ + assert len(a) == 2 + assert len(b) == 2 + direction = a - b + third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32) + + return third_pt + + +class TopDownEvalAffine(object): + """apply affine transform to image and coords + + Args: + trainsize (list): [w, h], the standard size used to train + use_udp (bool): whether to use Unbiased Data Processing. + records(dict): the dict contained the image and coords + + Returns: + records (dict): contain the image and coords after tranformed + + """ + + def __init__(self, trainsize, use_udp=False): + self.trainsize = trainsize + self.use_udp = use_udp + + def __call__(self, image, im_info): + rot = 0 + imshape = im_info['im_shape'][::-1] + center = im_info['center'] if 'center' in im_info else imshape / 2. + scale = im_info['scale'] if 'scale' in im_info else imshape + if self.use_udp: + trans = get_warp_matrix( + rot, center * 2.0, + [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + else: + trans = get_affine_transform(center, scale, rot, self.trainsize) + image = cv2.warpAffine( + image, + trans, (int(self.trainsize[0]), int(self.trainsize[1])), + flags=cv2.INTER_LINEAR) + + return image, im_info + + +def expand_crop(images, rect, expand_ratio=0.3): + imgh, imgw, c = images.shape + label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()] + if label != 0: + return None, None, None + org_rect = [xmin, ymin, xmax, ymax] + h_half = (ymax - ymin) * (1 + expand_ratio) / 2. + w_half = (xmax - xmin) * (1 + expand_ratio) / 2. + if h_half > w_half * 4 / 3: + w_half = h_half * 0.75 + center = [(ymin + ymax) / 2., (xmin + xmax) / 2.] + ymin = max(0, int(center[0] - h_half)) + ymax = min(imgh - 1, int(center[0] + h_half)) + xmin = max(0, int(center[1] - w_half)) + xmax = min(imgw - 1, int(center[1] + w_half)) + return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax], org_rect diff --git a/helper/page_detection/main.py b/helper/page_detection/main.py new file mode 100644 index 0000000..fd23166 --- /dev/null +++ b/helper/page_detection/main.py @@ -0,0 +1,69 @@ +from typing import List +from .pdf_detection import Pipeline +from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, PageDetectionResult +from tqdm import tqdm + + + +""" + 0 - Text + 1 - Title + 2 - Figure + 3 - Figure caption + 4 - Table + 5 - Table caption + 6 - Header + 7 - Footer + 8 - Reference + 9 - Equation +""" +pipeline = Pipeline('./models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer') + +effective_labels = [0, 1, 2, 4, 5] +# nms优先级,索引越低优先级越低,box重叠时优先保留表格 +label_scores = [1, 5, 0, 2, 4] +expand_pixel = 10 + + +def layout_analysis(image_paths) -> List[PageDetectionResult]: + layout_analysis_results = [] + for image_path in tqdm(image_paths, '版面分析'): + page_detecion_outputs = pipeline(image_path) + layout_boxes = [] + for i in range(len(page_detecion_outputs)): + clsid, box, confidence = page_detecion_outputs[i] + if clsid in effective_labels: + layout_boxes.append(LayoutBox(clsid, box, confidence)) + page_detecion_outputs = PageDetectionResult(layout_boxes, image_path) + + scores = [] + poses = [] + for box in page_detecion_outputs.boxes: + # 相同的label重叠时,保留面积更大的 + area = (box.pos[3] - box.pos[1]) * (box.pos[2] - box.pos[0]) + area_score = area / 5000000 + scores.append(label_scores.index(box.clsid) + area_score) + poses.append(box.pos) + indices = 
non_max_suppression(poses, scores, 0.2) + _boxes = [] + for i in indices: + _boxes.append(page_detecion_outputs.boxes[i]) + page_detecion_outputs.boxes = _boxes + + for i in range(len(page_detecion_outputs.boxes) - 1, -1, -1): + box = page_detecion_outputs.boxes[i] + if box.clsid in (0, 5): + # 移除Table box和Figure box中的Table caption box和Text box (有些扫描件会被识别为Figure) + for _box in page_detecion_outputs.boxes: + if _box.clsid != 2 and _box.clsid != 4: + continue + if box.pos[0] > _box.pos[0] and box.pos[1] > _box.pos[1] and box.pos[2] < _box.pos[2] and box.pos[3] < _box.pos[3]: + page_detecion_outputs.boxes.remove(box) + + # 将text和title合并起来,便于转成markdown格式 + page_detecion_outputs.boxes = merge_text_and_title_boxes(page_detecion_outputs.boxes, (0, 1)) + # 对box进行排序 + page_detecion_outputs.boxes.sort(key=lambda x: (x.pos[1], x.pos[0])) + layout_analysis_results.append(page_detecion_outputs) + + return layout_analysis_results diff --git a/helper/page_detection/pdf_detection.py b/helper/page_detection/pdf_detection.py new file mode 100644 index 0000000..44b484a --- /dev/null +++ b/helper/page_detection/pdf_detection.py @@ -0,0 +1,918 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import yaml +import glob +import json +from pathlib import Path + +import cv2 +import numpy as np +import math +import paddle +from paddle.inference import Config +from paddle.inference import create_predictor + +import sys +# add deploy path of PaddleDetection to sys.path +parent_path = os.path.abspath(os.path.join(__file__, *(['..']))) +sys.path.insert(0, parent_path) + +from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize +from picodet_postprocess import PicoDetPostProcess +from clrnet_postprocess import CLRNetPostProcess +from visualize import visualize_box_mask, imshow_lanes +from utils import argsparser, Timer, multiclass_nms, coco_clsid2catid + +# Global dictionary +SUPPORT_MODELS = { + 'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', + 'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', + 'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet', + 'PPLCNet', 'DETR', 'CenterTrack', 'CLRNet' +} + + + +class Detector(object): + """ + Args: + pred_config (object): config of model, defined by `Config(model_dir)` + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads 
(int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + enable_mkldnn_bfloat16 (bool): whether to turn on mkldnn bfloat16 + output_dir (str): The path of output + threshold (float): The threshold of score for visualization + delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT. + Used by action model. + """ + + def __init__(self, + model_dir, + device='CPU', + run_mode='paddle', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + enable_mkldnn_bfloat16=False, + output_dir='output', + threshold=0.5, + delete_shuffle_pass=False, + use_fd_format=False): + self.pred_config = self.set_config( + model_dir, use_fd_format=use_fd_format) + self.predictor, self.config = load_predictor( + model_dir, + self.pred_config.arch, + run_mode=run_mode, + batch_size=batch_size, + min_subgraph_size=self.pred_config.min_subgraph_size, + device=device, + use_dynamic_shape=self.pred_config.use_dynamic_shape, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn, + enable_mkldnn_bfloat16=enable_mkldnn_bfloat16, + delete_shuffle_pass=delete_shuffle_pass) + self.det_times = Timer() + self.batch_size = batch_size + self.output_dir = output_dir + self.threshold = threshold + self.device = device + + def set_config(self, model_dir, use_fd_format): + return PredictConfig(model_dir, use_fd_format=use_fd_format) + + def preprocess(self, image_list): + preprocess_ops = [] + for op_info in self.pred_config.preprocess_infos: + new_op_info = op_info.copy() + op_type = new_op_info.pop('type') + preprocess_ops.append(eval(op_type)(**new_op_info)) + + input_im_lst = [] + input_im_info_lst = [] + for im_path in image_list: + im, im_info = preprocess(im_path, preprocess_ops) + input_im_lst.append(im) + input_im_info_lst.append(im_info) + inputs = create_inputs(input_im_lst, input_im_info_lst) + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[i]) + if input_names[i] == 'x': + input_tensor.copy_from_cpu(inputs['image']) + else: + input_tensor.copy_from_cpu(inputs[input_names[i]]) + + return inputs + + def postprocess(self, inputs, result): + # postprocess output of predictor + np_boxes_num = result['boxes_num'] + assert isinstance(np_boxes_num, np.ndarray), \ + '`np_boxes_num` should be a `numpy.ndarray`' + + result = {k: v for k, v in result.items() if v is not None} + return result + + def filter_box(self, result, threshold): + np_boxes_num = result['boxes_num'] + boxes = result['boxes'] + start_idx = 0 + filter_boxes = [] + filter_num = [] + for i in range(len(np_boxes_num)): + boxes_num = np_boxes_num[i] + boxes_i = boxes[start_idx:start_idx + boxes_num, :] + idx = boxes_i[:, 1] > threshold + filter_boxes_i = boxes_i[idx, :] + filter_boxes.append(filter_boxes_i) + filter_num.append(filter_boxes_i.shape[0]) + start_idx += boxes_num + boxes = np.concatenate(filter_boxes) + filter_num = np.array(filter_num) + filter_res = {'boxes': boxes, 'boxes_num': filter_num} + return filter_res + + def predict(self, repeats=1, run_benchmark=False): + ''' + Args: + repeats (int): repeats number for prediction + Returns: + result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + MaskRCNN's result include 'masks': 
np.ndarray: + shape: [N, im_h, im_w] + ''' + # model prediction + np_boxes_num, np_boxes, np_masks = np.array([0]), None, None + + if run_benchmark: + for i in range(repeats): + self.predictor.run() + if self.device == 'GPU': + paddle.device.cuda.synchronize() + else: + paddle.device.synchronize(device=self.device.lower()) + + result = dict( + boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num) + return result + + for i in range(repeats): + self.predictor.run() + output_names = self.predictor.get_output_names() + boxes_tensor = self.predictor.get_output_handle(output_names[0]) + np_boxes = boxes_tensor.copy_to_cpu() + if len(output_names) == 1: + # some exported model can not get tensor 'bbox_num' + np_boxes_num = np.array([len(np_boxes)]) + else: + boxes_num = self.predictor.get_output_handle(output_names[1]) + np_boxes_num = boxes_num.copy_to_cpu() + if self.pred_config.mask: + masks_tensor = self.predictor.get_output_handle(output_names[2]) + np_masks = masks_tensor.copy_to_cpu() + result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num) + return result + + def merge_batch_result(self, batch_result): + if len(batch_result) == 1: + return batch_result[0] + res_key = batch_result[0].keys() + results = {k: [] for k in res_key} + for res in batch_result: + for k, v in res.items(): + results[k].append(v) + for k, v in results.items(): + if k not in ['masks', 'segm']: + results[k] = np.concatenate(v) + return results + + def get_timer(self): + return self.det_times + + def predict_image(self, + image_list, + threshold=0.5, + visual=True): + batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size) + results = [] + for i in range(batch_loop_cnt): + start_index = i * self.batch_size + end_index = min((i + 1) * self.batch_size, len(image_list)) + batch_image_list = image_list[start_index:end_index] + # preprocess + self.det_times.preprocess_time_s.start() + inputs = self.preprocess(batch_image_list) + self.det_times.preprocess_time_s.end() + + # model prediction + self.det_times.inference_time_s.start() + result = self.predict() + self.det_times.inference_time_s.end() + + # postprocess + self.det_times.postprocess_time_s.start() + result = self.postprocess(inputs, result) + self.det_times.postprocess_time_s.end() + self.det_times.img_num += len(batch_image_list) + + if visual: + visualize( + batch_image_list, + result, + self.pred_config.labels, + output_dir=self.output_dir, + threshold=self.threshold) + # TODO 在这里处理batch + results.append(result) + results = self.merge_batch_result(results) + boxes = results['boxes'] + expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1) + boxes = boxes[expect_boxes, :] + output = [] + for dt in boxes: + clsid, box, confidence = int(dt[0]), dt[2:].tolist(), dt[1] + output.append((clsid, box, confidence)) + return output + + +class DetectorSOLOv2(Detector): + """ + Args: + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to open MKLDNN + 
enable_mkldnn_bfloat16 (bool): Whether to turn on mkldnn bfloat16 + output_dir (str): The path of output + threshold (float): The threshold of score for visualization + + """ + + def __init__(self, + model_dir, + device='CPU', + run_mode='paddle', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + enable_mkldnn_bfloat16=False, + output_dir='./', + threshold=0.5, + use_fd_format=False): + super(DetectorSOLOv2, self).__init__( + model_dir=model_dir, + device=device, + run_mode=run_mode, + batch_size=batch_size, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn, + enable_mkldnn_bfloat16=enable_mkldnn_bfloat16, + output_dir=output_dir, + threshold=threshold, + use_fd_format=use_fd_format) + + def predict(self, repeats=1, run_benchmark=False): + ''' + Args: + repeats (int): repeat number for prediction + Returns: + result (dict): 'segm': np.ndarray,shape:[N, im_h, im_w] + 'cate_label': label of segm, shape:[N] + 'cate_score': confidence score of segm, shape:[N] + ''' + np_segms, np_label, np_score, np_boxes_num = None, None, None, np.array( + [0]) + + if run_benchmark: + for i in range(repeats): + self.predictor.run() + paddle.device.cuda.synchronize() + result = dict( + segm=np_segms, + label=np_label, + score=np_score, + boxes_num=np_boxes_num) + return result + + for i in range(repeats): + self.predictor.run() + output_names = self.predictor.get_output_names() + np_segms = self.predictor.get_output_handle(output_names[ + 0]).copy_to_cpu() + np_boxes_num = self.predictor.get_output_handle(output_names[ + 1]).copy_to_cpu() + np_label = self.predictor.get_output_handle(output_names[ + 2]).copy_to_cpu() + np_score = self.predictor.get_output_handle(output_names[ + 3]).copy_to_cpu() + + result = dict( + segm=np_segms, + label=np_label, + score=np_score, + boxes_num=np_boxes_num) + return result + + +class DetectorPicoDet(Detector): + """ + Args: + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to turn on MKLDNN + enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16 + """ + + def __init__(self, + model_dir, + device='CPU', + run_mode='paddle', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + enable_mkldnn_bfloat16=False, + output_dir='./', + threshold=0.5, + use_fd_format=False): + super(DetectorPicoDet, self).__init__( + model_dir=model_dir, + device=device, + run_mode=run_mode, + batch_size=batch_size, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn, + enable_mkldnn_bfloat16=enable_mkldnn_bfloat16, + output_dir=output_dir, + threshold=threshold, + 
use_fd_format=use_fd_format) + + def postprocess(self, inputs, result): + # postprocess output of predictor + np_score_list = result['boxes'] + np_boxes_list = result['boxes_num'] + postprocessor = PicoDetPostProcess( + inputs['image'].shape[2:], + inputs['im_shape'], + inputs['scale_factor'], + strides=self.pred_config.fpn_stride, + nms_threshold=self.pred_config.nms['nms_threshold']) + np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list) + result = dict(boxes=np_boxes, boxes_num=np_boxes_num) + return result + + def predict(self, repeats=1, run_benchmark=False): + ''' + Args: + repeats (int): repeat number for prediction + Returns: + result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + ''' + np_score_list, np_boxes_list = [], [] + + if run_benchmark: + for i in range(repeats): + self.predictor.run() + paddle.device.cuda.synchronize() + result = dict(boxes=np_score_list, boxes_num=np_boxes_list) + return result + + for i in range(repeats): + self.predictor.run() + np_score_list.clear() + np_boxes_list.clear() + output_names = self.predictor.get_output_names() + num_outs = int(len(output_names) / 2) + for out_idx in range(num_outs): + np_score_list.append( + self.predictor.get_output_handle(output_names[out_idx]) + .copy_to_cpu()) + np_boxes_list.append( + self.predictor.get_output_handle(output_names[ + out_idx + num_outs]).copy_to_cpu()) + result = dict(boxes=np_score_list, boxes_num=np_boxes_list) + return result + + +class DetectorCLRNet(Detector): + """ + Args: + model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16) + batch_size (int): size of pre batch in inference + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + cpu_threads (int): cpu threads + enable_mkldnn (bool): whether to turn on MKLDNN + enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16 + """ + + def __init__(self, + model_dir, + device='CPU', + run_mode='paddle', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + enable_mkldnn_bfloat16=False, + output_dir='./', + threshold=0.5, + use_fd_format=False): + super(DetectorCLRNet, self).__init__( + model_dir=model_dir, + device=device, + run_mode=run_mode, + batch_size=batch_size, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn, + enable_mkldnn_bfloat16=enable_mkldnn_bfloat16, + output_dir=output_dir, + threshold=threshold, + use_fd_format=use_fd_format) + + deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + with open(deploy_file) as f: + yml_conf = yaml.safe_load(f) + self.img_w = yml_conf['img_w'] + self.ori_img_h = yml_conf['ori_img_h'] + self.cut_height = yml_conf['cut_height'] + self.max_lanes = yml_conf['max_lanes'] + self.nms_thres = yml_conf['nms_thres'] + self.num_points = yml_conf['num_points'] + self.conf_threshold = yml_conf['conf_threshold'] + + def postprocess(self, inputs, result): + # postprocess output 
of predictor + lanes_list = result['lanes'] + postprocessor = CLRNetPostProcess( + img_w=self.img_w, + ori_img_h=self.ori_img_h, + cut_height=self.cut_height, + conf_threshold=self.conf_threshold, + nms_thres=self.nms_thres, + max_lanes=self.max_lanes, + num_points=self.num_points) + lanes = postprocessor(lanes_list) + result = dict(lanes=lanes) + return result + + def predict(self, repeats=1, run_benchmark=False): + ''' + Args: + repeats (int): repeat number for prediction + Returns: + result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + ''' + lanes_list = [] + + if run_benchmark: + for i in range(repeats): + self.predictor.run() + paddle.device.cuda.synchronize() + result = dict(lanes=lanes_list) + return result + + for i in range(repeats): + # TODO: check the output of predictor + self.predictor.run() + lanes_list.clear() + output_names = self.predictor.get_output_names() + num_outs = int(len(output_names) / 2) + if num_outs == 0: + lanes_list.append([]) + for out_idx in range(num_outs): + lanes_list.append( + self.predictor.get_output_handle(output_names[out_idx]) + .copy_to_cpu()) + result = dict(lanes=lanes_list) + return result + + +def create_inputs(imgs, im_info): + """generate input for different model type + Args: + imgs (list(numpy)): list of images (np.ndarray) + im_info (list(dict)): list of image info + Returns: + inputs (dict): input of model + """ + inputs = {} + + im_shape = [] + scale_factor = [] + if len(imgs) == 1: + inputs['image'] = np.array((imgs[0], )).astype('float32') + inputs['im_shape'] = np.array( + (im_info[0]['im_shape'], )).astype('float32') + inputs['scale_factor'] = np.array( + (im_info[0]['scale_factor'], )).astype('float32') + return inputs + + for e in im_info: + im_shape.append(np.array((e['im_shape'], )).astype('float32')) + scale_factor.append(np.array((e['scale_factor'], )).astype('float32')) + + inputs['im_shape'] = np.concatenate(im_shape, axis=0) + inputs['scale_factor'] = np.concatenate(scale_factor, axis=0) + + imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs] + max_shape_h = max([e[0] for e in imgs_shape]) + max_shape_w = max([e[1] for e in imgs_shape]) + padding_imgs = [] + for img in imgs: + im_c, im_h, im_w = img.shape[:] + padding_im = np.zeros( + (im_c, max_shape_h, max_shape_w), dtype=np.float32) + padding_im[:, :im_h, :im_w] = img + padding_imgs.append(padding_im) + inputs['image'] = np.stack(padding_imgs, axis=0) + return inputs + + +class PredictConfig(): + """set config of preprocess, postprocess and visualize + Args: + model_dir (str): root path of model.yml + """ + + def __init__(self, model_dir, use_fd_format=False): + # parsing Yaml config for Preprocess + fd_deploy_file = os.path.join(model_dir, 'inference.yml') + ppdet_deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + if use_fd_format: + if not os.path.exists(fd_deploy_file) and os.path.exists( + ppdet_deploy_file): + raise RuntimeError( + "Non-FD format model detected. Please set `use_fd_format` to False." + ) + deploy_file = fd_deploy_file + else: + if not os.path.exists(ppdet_deploy_file) and os.path.exists( + fd_deploy_file): + raise RuntimeError( + "FD format model detected. Please set `use_fd_format` to False." 
+ ) + deploy_file = ppdet_deploy_file + with open(deploy_file) as f: + yml_conf = yaml.safe_load(f) + self.check_model(yml_conf) + self.arch = yml_conf['arch'] + self.preprocess_infos = yml_conf['Preprocess'] + self.min_subgraph_size = yml_conf['min_subgraph_size'] + self.labels = yml_conf['label_list'] + self.mask = False + self.use_dynamic_shape = yml_conf['use_dynamic_shape'] + if 'mask' in yml_conf: + self.mask = yml_conf['mask'] + self.tracker = None + if 'tracker' in yml_conf: + self.tracker = yml_conf['tracker'] + if 'NMS' in yml_conf: + self.nms = yml_conf['NMS'] + if 'fpn_stride' in yml_conf: + self.fpn_stride = yml_conf['fpn_stride'] + if self.arch == 'RCNN' and yml_conf.get('export_onnx', False): + print( + 'The RCNN export model is used for ONNX and it only supports batch_size = 1' + ) + self.print_config() + + def check_model(self, yml_conf): + """ + Raises: + ValueError: loaded model not in supported model type + """ + for support_model in SUPPORT_MODELS: + if support_model in yml_conf['arch']: + return True + raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[ + 'arch'], SUPPORT_MODELS)) + + def print_config(self): + print('----------- Model Configuration -----------') + print('%s: %s' % ('Model Arch', self.arch)) + print('%s: ' % ('Transform Order')) + for op_info in self.preprocess_infos: + print('--%s: %s' % ('transform op', op_info['type'])) + print('--------------------------------------------') + + +def load_predictor(model_dir, + arch, + run_mode='paddle', + batch_size=1, + device='CPU', + min_subgraph_size=3, + use_dynamic_shape=False, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + enable_mkldnn_bfloat16=False, + delete_shuffle_pass=False): + """set AnalysisConfig, generate AnalysisPredictor + Args: + model_dir (str): root path of __model__ and __params__ + device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU + run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8) + use_dynamic_shape (bool): use dynamic shape or not + trt_min_shape (int): min shape for dynamic shape in trt + trt_max_shape (int): max shape for dynamic shape in trt + trt_opt_shape (int): opt shape for dynamic shape in trt + trt_calib_mode (bool): If the model is produced by TRT offline quantitative + calibration, trt_calib_mode need to set True + delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT. + Used by action model. + Returns: + predictor (PaddlePredictor): AnalysisPredictor + Raises: + ValueError: predict by TensorRT need device == 'GPU'. 
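+    Examples:
+        A minimal sketch of the call made by `Detector.__init__` above; the
+        model directory below is a placeholder, not a path shipped with this
+        repo::
+
+            predictor, config = load_predictor(
+                'inference_model/picodet_layout_infer', 'PicoDet',
+                run_mode='paddle', device='GPU', batch_size=1)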
+ """ + if device != 'GPU' and run_mode != 'paddle': + raise ValueError( + "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}" + .format(run_mode, device)) + + if paddle.__version__ >= '3.0.0' or paddle.__version__ == '0.0.0': + model_path = model_dir + model_prefix = 'model' + infer_param = os.path.join(model_dir, 'model.pdiparams') + if not os.path.exists(infer_param): + if paddle.framework.use_pir_api(): + infer_model = os.path.join(model_dir, 'inference.pdmodel') + else: + infer_model = os.path.join(model_dir, 'inference.json') + if not os.path.exists(infer_model): + raise ValueError( + "Cannot find any inference model in dir: {}.".format(model_dir)) + + config = Config(model_path, model_prefix) + + else: + infer_model = os.path.join(model_dir, 'model.pdmodel') + infer_params = os.path.join(model_dir, 'model.pdiparams') + if not os.path.exists(infer_model): + infer_model = os.path.join(model_dir, 'inference.pdmodel') + infer_params = os.path.join(model_dir, 'inference.pdiparams') + if not os.path.exists(infer_model): + raise ValueError( + "Cannot find any inference model in dir: {},".format(model_dir)) + config = Config(infer_model, infer_params) + + if device == 'GPU': + # initial GPU memory(M), device ID + config.enable_use_gpu(200, 0) + # optimize graph and fuse op + config.switch_ir_optim(True) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(cpu_threads) + if enable_mkldnn: + try: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + if enable_mkldnn_bfloat16: + config.enable_mkldnn_bfloat16() + except Exception as e: + print( + "The current environment does not support `mkldnn`, so disable mkldnn." + ) + pass + + precision_map = { + 'trt_int8': Config.Precision.Int8, + 'trt_fp32': Config.Precision.Float32, + 'trt_fp16': Config.Precision.Half + } + if run_mode in precision_map.keys(): + config.enable_tensorrt_engine( + workspace_size=(1 << 25) * batch_size, + max_batch_size=batch_size, + min_subgraph_size=min_subgraph_size, + precision_mode=precision_map[run_mode], + use_static=False, + use_calib_mode=trt_calib_mode) + if FLAGS.collect_trt_shape_info: + config.collect_shape_range_info(FLAGS.tuned_trt_shape_file) + elif os.path.exists(FLAGS.tuned_trt_shape_file): + print(f'Use dynamic shape file: ' + f'{FLAGS.tuned_trt_shape_file} for TRT...') + config.enable_tuned_tensorrt_dynamic_shape( + FLAGS.tuned_trt_shape_file, True) + + if use_dynamic_shape: + min_input_shape = { + 'image': [batch_size, 3, trt_min_shape, trt_min_shape], + 'scale_factor': [batch_size, 2] + } + max_input_shape = { + 'image': [batch_size, 3, trt_max_shape, trt_max_shape], + 'scale_factor': [batch_size, 2] + } + opt_input_shape = { + 'image': [batch_size, 3, trt_opt_shape, trt_opt_shape], + 'scale_factor': [batch_size, 2] + } + config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape, + opt_input_shape) + print('trt set dynamic shape done!') + + # disable print log when predict + config.disable_glog_info() + # enable shared memory + config.enable_memory_optim() + # disable feed, fetch OP, needed by zero_copy_run + config.switch_use_feed_fetch_ops(False) + if delete_shuffle_pass: + config.delete_pass("shuffle_channel_detect_pass") + predictor = create_predictor(config) + return predictor, config + + +def visualize(image_list, result, labels, output_dir='output/', threshold=0.5): + # visualize the predict result + if 'lanes' in result: + for idx, image_file in enumerate(image_list): + 
lanes = result['lanes'][idx] + img = cv2.imread(image_file) + out_file = os.path.join(output_dir, os.path.basename(image_file)) + # hard code + lanes = [lane.to_array([], ) for lane in lanes] + imshow_lanes(img, lanes, out_file=out_file) + return + start_idx = 0 + for idx, image_file in enumerate(image_list): + im_bboxes_num = result['boxes_num'][idx] + im_results = {} + if 'boxes' in result: + im_results['boxes'] = result['boxes'][start_idx:start_idx + + im_bboxes_num, :] + if 'masks' in result: + im_results['masks'] = result['masks'][start_idx:start_idx + + im_bboxes_num, :] + if 'segm' in result: + im_results['segm'] = result['segm'][start_idx:start_idx + + im_bboxes_num, :] + if 'label' in result: + im_results['label'] = result['label'][start_idx:start_idx + + im_bboxes_num] + if 'score' in result: + im_results['score'] = result['score'][start_idx:start_idx + + im_bboxes_num] + + start_idx += im_bboxes_num + im = visualize_box_mask( + image_file, im_results, labels, threshold=threshold) + img_name = os.path.split(image_file)[-1] + if not os.path.exists(output_dir): + os.makedirs(output_dir) + out_path = os.path.join(output_dir, img_name) + im.save(out_path, quality=95) + print("save result to: " + out_path) + + +def print_arguments(args): + print('----------- Running Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------') + + +class Pipeline(object): + + def __init__(self, model_dir): + if FLAGS.use_fd_format: + deploy_file = os.path.join(model_dir, 'inference.yml') + else: + deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + with open(deploy_file) as f: + yml_conf = yaml.safe_load(f) + arch = yml_conf['arch'] + detector_func = 'Detector' + if arch == 'SOLOv2': + detector_func = 'DetectorSOLOv2' + elif arch == 'PicoDet': + detector_func = 'DetectorPicoDet' + elif arch == "CLRNet": + detector_func = 'DetectorCLRNet' + + self.detector = eval(detector_func)( + model_dir, + device=FLAGS.device, + run_mode=FLAGS.run_mode, + batch_size=FLAGS.batch_size, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn, + enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16, + threshold=FLAGS.threshold, + output_dir=FLAGS.output_dir, + use_fd_format=FLAGS.use_fd_format) + + def __call__(self, image_path): + if FLAGS.image_dir is None and image_path is not None: + assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None" + if isinstance(image_path, str): + image_path = [image_path] + results = self.detector.predict_image( + image_path, + visual=FLAGS.save_images) + return results + + +paddle.enable_static() +parser = argsparser() +FLAGS = parser.parse_args() +print_arguments(FLAGS) +FLAGS.device = 'GPU' +FLAGS.save_images = False +FLAGS.device = FLAGS.device.upper() +assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU', 'MLU', 'GCU' + ], "device should be CPU, GPU, XPU, MLU, NPU or GCU" +assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device" + +assert not ( + FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True +), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16' diff --git a/helper/page_detection/picodet_postprocess.py b/helper/page_detection/picodet_postprocess.py new file mode 100644 index 0000000..7df13f8 --- /dev/null +++ 
b/helper/page_detection/picodet_postprocess.py @@ -0,0 +1,227 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from scipy.special import softmax + + +def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200): + """ + Args: + box_scores (N, 5): boxes in corner-form and probabilities. + iou_threshold: intersection over union threshold. + top_k: keep top_k results. If k <= 0, keep all the results. + candidate_size: only consider the candidates with the highest scores. + Returns: + picked: a list of indexes of the kept boxes + """ + scores = box_scores[:, -1] + boxes = box_scores[:, :-1] + picked = [] + indexes = np.argsort(scores) + indexes = indexes[-candidate_size:] + while len(indexes) > 0: + current = indexes[-1] + picked.append(current) + if 0 < top_k == len(picked) or len(indexes) == 1: + break + current_box = boxes[current, :] + indexes = indexes[:-1] + rest_boxes = boxes[indexes, :] + iou = iou_of( + rest_boxes, + np.expand_dims( + current_box, axis=0), ) + indexes = indexes[iou <= iou_threshold] + + return box_scores[picked, :] + + +def iou_of(boxes0, boxes1, eps=1e-5): + """Return intersection-over-union (Jaccard index) of boxes. + Args: + boxes0 (N, 4): ground truth boxes. + boxes1 (N or 1, 4): predicted boxes. + eps: a small number to avoid 0 as denominator. + Returns: + iou (N): IoU values. + """ + overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2]) + overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:]) + + overlap_area = area_of(overlap_left_top, overlap_right_bottom) + area0 = area_of(boxes0[..., :2], boxes0[..., 2:]) + area1 = area_of(boxes1[..., :2], boxes1[..., 2:]) + return overlap_area / (area0 + area1 - overlap_area + eps) + + +def area_of(left_top, right_bottom): + """Compute the areas of rectangles given two corners. + Args: + left_top (N, 2): left top corner. + right_bottom (N, 2): right bottom corner. + Returns: + area (N): return the area. 
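+    Example:
+        A quick sanity check with illustrative corner values (not taken from
+        any model output)::
+
+            # one 2x2 box: left-top (0, 0), right-bottom (2, 2) -> area 4.0
+            area_of(np.array([[0., 0.]]), np.array([[2., 2.]]))  # array([4.])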
+ """ + hw = np.clip(right_bottom - left_top, 0.0, None) + return hw[..., 0] * hw[..., 1] + + +class PicoDetPostProcess(object): + """ + Args: + input_shape (int): network input image size + ori_shape (int): ori image shape of before padding + scale_factor (float): scale factor of ori image + enable_mkldnn (bool): whether to open MKLDNN + """ + + def __init__(self, + input_shape, + ori_shape, + scale_factor, + strides=[8, 16, 32, 64], + score_threshold=0.4, + nms_threshold=0.5, + nms_top_k=1000, + keep_top_k=100): + self.ori_shape = ori_shape + self.input_shape = input_shape + self.scale_factor = scale_factor + self.strides = strides + self.score_threshold = score_threshold + self.nms_threshold = nms_threshold + self.nms_top_k = nms_top_k + self.keep_top_k = keep_top_k + + def warp_boxes(self, boxes, ori_shape): + """Apply transform to boxes + """ + width, height = ori_shape[1], ori_shape[0] + n = len(boxes) + if n: + # warp points + xy = np.ones((n * 4, 3)) + xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape( + n * 4, 2) # x1y1, x2y2, x1y2, x2y1 + # xy = xy @ M.T # transform + xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale + # create new boxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + xy = np.concatenate( + (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + # clip boxes + xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width) + xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height) + return xy.astype(np.float32) + else: + return boxes + + def __call__(self, scores, raw_boxes): + batch_size = raw_boxes[0].shape[0] + reg_max = int(raw_boxes[0].shape[-1] / 4 - 1) + out_boxes_num = [] + out_boxes_list = [] + for batch_id in range(batch_size): + # generate centers + decode_boxes = [] + select_scores = [] + for stride, box_distribute, score in zip(self.strides, raw_boxes, + scores): + box_distribute = box_distribute[batch_id] + score = score[batch_id] + # centers + fm_h = self.input_shape[0] / stride + fm_w = self.input_shape[1] / stride + h_range = np.arange(fm_h) + w_range = np.arange(fm_w) + ww, hh = np.meshgrid(w_range, h_range) + ct_row = (hh.flatten() + 0.5) * stride + ct_col = (ww.flatten() + 0.5) * stride + center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1) + + # box distribution to distance + reg_range = np.arange(reg_max + 1) + box_distance = box_distribute.reshape((-1, reg_max + 1)) + box_distance = softmax(box_distance, axis=1) + box_distance = box_distance * np.expand_dims(reg_range, axis=0) + box_distance = np.sum(box_distance, axis=1).reshape((-1, 4)) + box_distance = box_distance * stride + + # top K candidate + topk_idx = np.argsort(score.max(axis=1))[::-1] + topk_idx = topk_idx[:self.nms_top_k] + center = center[topk_idx] + score = score[topk_idx] + box_distance = box_distance[topk_idx] + + # decode box + decode_box = center + [-1, -1, 1, 1] * box_distance + + select_scores.append(score) + decode_boxes.append(decode_box) + + # nms + bboxes = np.concatenate(decode_boxes, axis=0) + confidences = np.concatenate(select_scores, axis=0) + picked_box_probs = [] + picked_labels = [] + for class_index in range(0, confidences.shape[1]): + probs = confidences[:, class_index] + mask = probs > self.score_threshold + probs = probs[mask] + if probs.shape[0] == 0: + continue + subset_boxes = bboxes[mask, :] + box_probs = np.concatenate( + [subset_boxes, probs.reshape(-1, 1)], axis=1) + box_probs = hard_nms( + box_probs, + iou_threshold=self.nms_threshold, + top_k=self.keep_top_k, ) + picked_box_probs.append(box_probs) + picked_labels.extend([class_index] * 
box_probs.shape[0]) + + if len(picked_box_probs) == 0: + out_boxes_list.append(np.empty((0, 4))) + out_boxes_num.append(0) + + else: + picked_box_probs = np.concatenate(picked_box_probs) + + # resize output boxes + picked_box_probs[:, :4] = self.warp_boxes( + picked_box_probs[:, :4], self.ori_shape[batch_id]) + im_scale = np.concatenate([ + self.scale_factor[batch_id][::-1], + self.scale_factor[batch_id][::-1] + ]) + picked_box_probs[:, :4] /= im_scale + # clas score box + out_boxes_list.append( + np.concatenate( + [ + np.expand_dims( + np.array(picked_labels), + axis=-1), np.expand_dims( + picked_box_probs[:, 4], axis=-1), + picked_box_probs[:, :4] + ], + axis=1)) + out_boxes_num.append(len(picked_labels)) + + out_boxes_list = np.concatenate(out_boxes_list, axis=0) + out_boxes_num = np.asarray(out_boxes_num).astype(np.int32) + return out_boxes_list, out_boxes_num diff --git a/helper/page_detection/preprocess.py b/helper/page_detection/preprocess.py new file mode 100644 index 0000000..1936d3e --- /dev/null +++ b/helper/page_detection/preprocess.py @@ -0,0 +1,549 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np +import imgaug.augmenters as iaa +from keypoint_preprocess import get_affine_transform +from PIL import Image + + +def decode_image(im_file, im_info): + """read rgb image + Args: + im_file (str|np.ndarray): input can be image path or np.ndarray + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + if isinstance(im_file, str): + with open(im_file, 'rb') as f: + im_read = f.read() + data = np.frombuffer(im_read, dtype='uint8') + im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + else: + im = im_file + im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32) + im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32) + return im, im_info + + +class Resize_Mult32(object): + """resize image by target_size and max_size + Args: + target_size (int): the target size of image + keep_ratio (bool): whether keep_ratio or not, default true + interp (int): method of resize + """ + + def __init__(self, limit_side_len, limit_type, interp=cv2.INTER_LINEAR): + self.limit_side_len = limit_side_len + self.limit_type = limit_type + self.interp = interp + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + im_channel = im.shape[2] + im_scale_y, im_scale_x = self.generate_scale(im) + im = cv2.resize( + im, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') + im_info['scale_factor'] = np.array( + [im_scale_y, im_scale_x]).astype('float32') + return im, im_info + + def generate_scale(self, img): + """ + 
Args: + img (np.ndarray): image (np.ndarray) + Returns: + im_scale_x: the resize ratio of X + im_scale_y: the resize ratio of Y + """ + limit_side_len = self.limit_side_len + h, w, c = img.shape + + # limit the max side + if self.limit_type == 'max': + if h > w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + elif self.limit_type == 'min': + if h < w: + ratio = float(limit_side_len) / h + else: + ratio = float(limit_side_len) / w + elif self.limit_type == 'resize_long': + ratio = float(limit_side_len) / max(h, w) + else: + raise Exception('not support limit type, image ') + resize_h = int(h * ratio) + resize_w = int(w * ratio) + + resize_h = max(int(round(resize_h / 32) * 32), 32) + resize_w = max(int(round(resize_w / 32) * 32), 32) + + im_scale_y = resize_h / float(h) + im_scale_x = resize_w / float(w) + return im_scale_y, im_scale_x + + +class Resize(object): + """resize image by target_size and max_size + Args: + target_size (int): the target size of image + keep_ratio (bool): whether keep_ratio or not, default true + interp (int): method of resize + """ + + def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR): + if isinstance(target_size, int): + target_size = [target_size, target_size] + self.target_size = target_size + self.keep_ratio = keep_ratio + self.interp = interp + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + assert len(self.target_size) == 2 + assert self.target_size[0] > 0 and self.target_size[1] > 0 + im_channel = im.shape[2] + im_scale_y, im_scale_x = self.generate_scale(im) + im = cv2.resize( + im, + None, + None, + fx=im_scale_x, + fy=im_scale_y, + interpolation=self.interp) + im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') + im_info['scale_factor'] = np.array( + [im_scale_y, im_scale_x]).astype('float32') + return im, im_info + + def generate_scale(self, im): + """ + Args: + im (np.ndarray): image (np.ndarray) + Returns: + im_scale_x: the resize ratio of X + im_scale_y: the resize ratio of Y + """ + origin_shape = im.shape[:2] + im_c = im.shape[2] + if self.keep_ratio: + im_size_min = np.min(origin_shape) + im_size_max = np.max(origin_shape) + target_size_min = np.min(self.target_size) + target_size_max = np.max(self.target_size) + im_scale = float(target_size_min) / float(im_size_min) + if np.round(im_scale * im_size_max) > target_size_max: + im_scale = float(target_size_max) / float(im_size_max) + im_scale_x = im_scale + im_scale_y = im_scale + else: + resize_h, resize_w = self.target_size + im_scale_y = resize_h / float(origin_shape[0]) + im_scale_x = resize_w / float(origin_shape[1]) + return im_scale_y, im_scale_x + + +class ShortSizeScale(object): + """ + Scale images by short size. + Args: + short_size(float | int): Short size of an image will be scaled to the short_size. + fixed_ratio(bool): Set whether to zoom according to a fixed ratio. default: True + do_round(bool): Whether to round up when calculating the zoom ratio. default: False + backend(str): Choose pillow or cv2 as the graphics processing backend. 
default: 'pillow' + """ + + def __init__(self, + short_size, + fixed_ratio=True, + keep_ratio=None, + do_round=False, + backend='pillow'): + self.short_size = short_size + assert (fixed_ratio and not keep_ratio) or ( + not fixed_ratio + ), "fixed_ratio and keep_ratio cannot be true at the same time" + self.fixed_ratio = fixed_ratio + self.keep_ratio = keep_ratio + self.do_round = do_round + + assert backend in [ + 'pillow', 'cv2' + ], "Scale's backend must be pillow or cv2, but get {backend}" + + self.backend = backend + + def __call__(self, img): + """ + Performs resize operations. + Args: + img (PIL.Image): a PIL.Image. + return: + resized_img: a PIL.Image after scaling. + """ + + result_img = None + + if isinstance(img, np.ndarray): + h, w, _ = img.shape + elif isinstance(img, Image.Image): + w, h = img.size + else: + raise NotImplementedError + + if w <= h: + ow = self.short_size + if self.fixed_ratio: # default is True + oh = int(self.short_size * 4.0 / 3.0) + elif not self.keep_ratio: # no + oh = self.short_size + else: + scale_factor = self.short_size / w + oh = int(h * float(scale_factor) + + 0.5) if self.do_round else int(h * self.short_size / w) + ow = int(w * float(scale_factor) + + 0.5) if self.do_round else int(w * self.short_size / h) + else: + oh = self.short_size + if self.fixed_ratio: + ow = int(self.short_size * 4.0 / 3.0) + elif not self.keep_ratio: # no + ow = self.short_size + else: + scale_factor = self.short_size / h + oh = int(h * float(scale_factor) + + 0.5) if self.do_round else int(h * self.short_size / w) + ow = int(w * float(scale_factor) + + 0.5) if self.do_round else int(w * self.short_size / h) + + if type(img) == np.ndarray: + img = Image.fromarray(img, mode='RGB') + + if self.backend == 'pillow': + result_img = img.resize((ow, oh), Image.BILINEAR) + elif self.backend == 'cv2' and (self.keep_ratio is not None): + result_img = cv2.resize( + img, (ow, oh), interpolation=cv2.INTER_LINEAR) + else: + result_img = Image.fromarray( + cv2.resize( + np.asarray(img), (ow, oh), interpolation=cv2.INTER_LINEAR)) + + return result_img + + +class NormalizeImage(object): + """normalize image + Args: + mean (list): im - mean + std (list): im / std + is_scale (bool): whether need im / 255 + norm_type (str): type in ['mean_std', 'none'] + """ + + def __init__(self, mean, std, is_scale=True, norm_type='mean_std'): + self.mean = mean + self.std = std + self.is_scale = is_scale + self.norm_type = norm_type + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + im = im.astype(np.float32, copy=False) + if self.is_scale: + scale = 1.0 / 255.0 + im *= scale + + if self.norm_type == 'mean_std': + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + im -= mean + im /= std + return im, im_info + + +class Permute(object): + """permute image + Args: + to_bgr (bool): whether convert RGB to BGR + channel_first (bool): whether convert HWC to CHW + """ + + def __init__(self, ): + super(Permute, self).__init__() + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + im = im.transpose((2, 0, 1)).copy() + return im, im_info + + +class PadStride(object): + """ padding image for model with FPN, 
instead PadBatch(pad_to_stride) in original config + Args: + stride (bool): model with FPN need image shape % stride == 0 + """ + + def __init__(self, stride=0): + self.coarsest_stride = stride + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + coarsest_stride = self.coarsest_stride + if coarsest_stride <= 0: + return im, im_info + im_c, im_h, im_w = im.shape + pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride) + pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride) + padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32) + padding_im[:, :im_h, :im_w] = im + return padding_im, im_info + + +class LetterBoxResize(object): + def __init__(self, target_size): + """ + Resize image to target size, convert normalized xywh to pixel xyxy + format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]). + Args: + target_size (int|list): image target size. + """ + super(LetterBoxResize, self).__init__() + if isinstance(target_size, int): + target_size = [target_size, target_size] + self.target_size = target_size + + def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)): + # letterbox: resize a rectangular image to a padded rectangular + shape = img.shape[:2] # [height, width] + ratio_h = float(height) / shape[0] + ratio_w = float(width) / shape[1] + ratio = min(ratio_h, ratio_w) + new_shape = (round(shape[1] * ratio), + round(shape[0] * ratio)) # [width, height] + padw = (width - new_shape[0]) / 2 + padh = (height - new_shape[1]) / 2 + top, bottom = round(padh - 0.1), round(padh + 0.1) + left, right = round(padw - 0.1), round(padw + 0.1) + + img = cv2.resize( + img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border + img = cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, + value=color) # padded rectangular + return img, ratio, padw, padh + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + assert len(self.target_size) == 2 + assert self.target_size[0] > 0 and self.target_size[1] > 0 + height, width = self.target_size + h, w = im.shape[:2] + im, ratio, padw, padh = self.letterbox(im, height=height, width=width) + + new_shape = [round(h * ratio), round(w * ratio)] + im_info['im_shape'] = np.array(new_shape, dtype=np.float32) + im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32) + return im, im_info + + +class Pad(object): + def __init__(self, size, fill_value=[114.0, 114.0, 114.0]): + """ + Pad image to a specified size. 
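+        The image is copied into the top-left corner of a canvas filled with
+        `fill_value`; no resizing is performed (see `__call__` below).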
+ Args: + size (list[int]): image target size + fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0) + """ + super(Pad, self).__init__() + if isinstance(size, int): + size = [size, size] + self.size = size + self.fill_value = fill_value + + def __call__(self, im, im_info): + im_h, im_w = im.shape[:2] + h, w = self.size + if h == im_h and w == im_w: + im = im.astype(np.float32) + return im, im_info + + canvas = np.ones((h, w, 3), dtype=np.float32) + canvas *= np.array(self.fill_value, dtype=np.float32) + canvas[0:im_h, 0:im_w, :] = im.astype(np.float32) + im = canvas + return im, im_info + + +class WarpAffine(object): + """Warp affine the image + """ + + def __init__(self, + keep_res=False, + pad=31, + input_h=512, + input_w=512, + scale=0.4, + shift=0.1, + down_ratio=4): + self.keep_res = keep_res + self.pad = pad + self.input_h = input_h + self.input_w = input_w + self.scale = scale + self.shift = shift + self.down_ratio = down_ratio + + def __call__(self, im, im_info): + """ + Args: + im (np.ndarray): image (np.ndarray) + im_info (dict): info of image + Returns: + im (np.ndarray): processed image (np.ndarray) + im_info (dict): info of processed image + """ + img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) + + h, w = img.shape[:2] + + if self.keep_res: + # True in detection eval/infer + input_h = (h | self.pad) + 1 + input_w = (w | self.pad) + 1 + s = np.array([input_w, input_h], dtype=np.float32) + c = np.array([w // 2, h // 2], dtype=np.float32) + + else: + # False in centertrack eval_mot/eval_mot + s = max(h, w) * 1.0 + input_h, input_w = self.input_h, self.input_w + c = np.array([w / 2., h / 2.], dtype=np.float32) + + trans_input = get_affine_transform(c, s, 0, [input_w, input_h]) + img = cv2.resize(img, (w, h)) + inp = cv2.warpAffine( + img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR) + + if not self.keep_res: + out_h = input_h // self.down_ratio + out_w = input_w // self.down_ratio + trans_output = get_affine_transform(c, s, 0, [out_w, out_h]) + + im_info.update({ + 'center': c, + 'scale': s, + 'out_height': out_h, + 'out_width': out_w, + 'inp_height': input_h, + 'inp_width': input_w, + 'trans_input': trans_input, + 'trans_output': trans_output, + }) + return inp, im_info + + +class CULaneResize(object): + def __init__(self, img_h, img_w, cut_height, prob=0.5): + super(CULaneResize, self).__init__() + self.img_h = img_h + self.img_w = img_w + self.cut_height = cut_height + self.prob = prob + + def __call__(self, im, im_info): + # cut + im = im[self.cut_height:, :, :] + # resize + transform = iaa.Sometimes(self.prob, + iaa.Resize({ + "height": self.img_h, + "width": self.img_w + })) + im = transform(image=im.copy().astype(np.uint8)) + + im = im.astype(np.float32) / 255. 
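+        # pixel values are scaled to floats in [0, 1] at this point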
+ # check transpose is need whether the func decode_image is equal to CULaneDataSet cv.imread + im = im.transpose(2, 0, 1) + + return im, im_info + + +def preprocess(im, preprocess_ops): + # process image by preprocess_ops + im_info = { + 'scale_factor': np.array( + [1., 1.], dtype=np.float32), + 'im_shape': None, + } + im, im_info = decode_image(im, im_info) + for operator in preprocess_ops: + im, im_info = operator(im, im_info) + return im, im_info diff --git a/helper/page_detection/start.sh b/helper/page_detection/start.sh new file mode 100755 index 0000000..014d60c --- /dev/null +++ b/helper/page_detection/start.sh @@ -0,0 +1,6 @@ +export FLAGS_enable_pir_api=0 + +python3 pdf_detection.py \ + --model_dir=/mnt/research/PaddleOCR/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer \ + --image_file=/mnt/research/PaddleOCR/demo-75-images/12.jpg \ + --device=GPU diff --git a/helper/page_detection/utils.py b/helper/page_detection/utils.py new file mode 100644 index 0000000..17184ef --- /dev/null +++ b/helper/page_detection/utils.py @@ -0,0 +1,629 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import os +import ast +import argparse +from typing import List, Tuple +import numpy as np + + +class LayoutBox(object): + def __init__(self, clsid: int, pos: List[float], confidence: float): + self.clsid = clsid + self.pos = pos + self.confidence = confidence + + +class PageDetectionResult(object): + def __init__(self, boxes: List[LayoutBox], image_path: str): + self.boxes = boxes + self.image_path = image_path + + +def argsparser(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--image_dir", + type=str, + default=None, + help="Dir of image file, `image_file` has a higher priority.") + parser.add_argument( + "--batch_size", type=int, default=1, help="batch_size for inference.") + parser.add_argument( + "--video_file", + type=str, + default=None, + help="Path of video file, `video_file` or `camera_id` has a highest priority." + ) + parser.add_argument( + "--camera_id", + type=int, + default=-1, + help="device id of camera to predict.") + parser.add_argument( + "--threshold", type=float, default=0.5, help="Threshold of score.") + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Directory of output visualization files.") + parser.add_argument( + "--run_mode", + type=str, + default='paddle', + help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU." 
+ ) + parser.add_argument( + "--use_gpu", + type=ast.literal_eval, + default=False, + help="Deprecated, please use `--device`.") + parser.add_argument( + "--run_benchmark", + type=ast.literal_eval, + default=False, + help="Whether to predict a image_file repeatedly for benchmark") + parser.add_argument( + "--enable_mkldnn", + type=ast.literal_eval, + default=False, + help="Whether use mkldnn with CPU.") + parser.add_argument( + "--enable_mkldnn_bfloat16", + type=ast.literal_eval, + default=False, + help="Whether use mkldnn bfloat16 inference with CPU.") + parser.add_argument( + "--cpu_threads", type=int, default=1, help="Num of threads with CPU.") + parser.add_argument( + "--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.") + parser.add_argument( + "--trt_max_shape", + type=int, + default=1280, + help="max_shape for TensorRT.") + parser.add_argument( + "--trt_opt_shape", + type=int, + default=640, + help="opt_shape for TensorRT.") + parser.add_argument( + "--trt_calib_mode", + type=bool, + default=False, + help="If the model is produced by TRT offline quantitative " + "calibration, trt_calib_mode need to set True.") + parser.add_argument( + '--save_images', + type=ast.literal_eval, + default=True, + help='Save visualization image results.') + parser.add_argument( + '--save_mot_txts', + action='store_true', + help='Save tracking results (txt).') + parser.add_argument( + '--save_mot_txt_per_img', + action='store_true', + help='Save tracking results (txt) for each image.') + parser.add_argument( + '--scaled', + type=bool, + default=False, + help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 " + "True in general detector.") + parser.add_argument( + "--tracker_config", type=str, default=None, help=("tracker donfig")) + parser.add_argument( + "--reid_model_dir", + type=str, + default=None, + help=("Directory include:'model.pdiparams', 'model.pdmodel', " + "'infer_cfg.yml', created by tools/export_model.py.")) + parser.add_argument( + "--reid_batch_size", + type=int, + default=50, + help="max batch_size for reid model inference.") + parser.add_argument( + '--use_dark', + type=ast.literal_eval, + default=True, + help='whether to use darkpose to get better keypoint position predict ') + parser.add_argument( + "--action_file", + type=str, + default=None, + help="Path of input file for action recognition.") + parser.add_argument( + "--window_size", + type=int, + default=50, + help="Temporal size of skeleton feature for action recognition.") + parser.add_argument( + "--random_pad", + type=ast.literal_eval, + default=False, + help="Whether do random padding for action recognition.") + parser.add_argument( + "--save_results", + action='store_true', + default=False, + help="Whether save detection result to file using coco format") + parser.add_argument( + '--use_coco_category', + action='store_true', + default=False, + help='Whether to use the coco format dictionary `clsid2catid`') + parser.add_argument( + "--slice_infer", + action='store_true', + help="Whether to slice the image and merge the inference results for small object detection." 
+ ) + parser.add_argument( + '--slice_size', + nargs='+', + type=int, + default=[640, 640], + help="Height of the sliced image.") + parser.add_argument( + "--overlap_ratio", + nargs='+', + type=float, + default=[0.25, 0.25], + help="Overlap height ratio of the sliced image.") + parser.add_argument( + "--combine_method", + type=str, + default='nms', + help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']." + ) + parser.add_argument( + "--match_threshold", + type=float, + default=0.6, + help="Combine method matching threshold.") + parser.add_argument( + "--match_metric", + type=str, + default='ios', + help="Combine method matching metric, choose in ['iou', 'ios'].") + parser.add_argument( + "--collect_trt_shape_info", + action='store_true', + default=False, + help="Whether to collect dynamic shape before using tensorrt.") + parser.add_argument( + "--tuned_trt_shape_file", + type=str, + default="shape_range_info.pbtxt", + help="Path of a dynamic shape file for tensorrt.") + parser.add_argument("--use_fd_format", action="store_true") + parser.add_argument( + "--task_type", + type=str, + default='Detection', + help="How to save the coco result, it only work with save_results==True. Optional inputs are Rotate or Detection, default is Detection." + ) + return parser + + +class Times(object): + def __init__(self): + self.time = 0. + # start time + self.st = 0. + # end time + self.et = 0. + + def start(self): + self.st = time.time() + + def end(self, repeats=1, accumulative=True): + self.et = time.time() + if accumulative: + self.time += (self.et - self.st) / repeats + else: + self.time = (self.et - self.st) / repeats + + def reset(self): + self.time = 0. + self.st = 0. + self.et = 0. + + def value(self): + return round(self.time, 4) + + +class Timer(Times): + def __init__(self, with_tracker=False): + super(Timer, self).__init__() + self.with_tracker = with_tracker + self.preprocess_time_s = Times() + self.inference_time_s = Times() + self.postprocess_time_s = Times() + self.tracking_time_s = Times() + self.img_num = 0 + + def info(self, average=False): + pre_time = self.preprocess_time_s.value() + infer_time = self.inference_time_s.value() + post_time = self.postprocess_time_s.value() + track_time = self.tracking_time_s.value() + + total_time = pre_time + infer_time + post_time + if self.with_tracker: + total_time = total_time + track_time + total_time = round(total_time, 4) + print("------------------ Inference Time Info ----------------------") + print("total_time(ms): {}, img_num: {}".format(total_time * 1000, + self.img_num)) + preprocess_time = round(pre_time / max(1, self.img_num), + 4) if average else pre_time + postprocess_time = round(post_time / max(1, self.img_num), + 4) if average else post_time + inference_time = round(infer_time / max(1, self.img_num), + 4) if average else infer_time + tracking_time = round(track_time / max(1, self.img_num), + 4) if average else track_time + + average_latency = total_time / max(1, self.img_num) + qps = 0 + if total_time > 0: + qps = 1 / average_latency + print("average latency time(ms): {:.2f}, QPS: {:2f}".format( + average_latency * 1000, qps)) + if self.with_tracker: + print( + "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}". + format(preprocess_time * 1000, inference_time * 1000, + postprocess_time * 1000, tracking_time * 1000)) + else: + print( + "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}". 
+ format(preprocess_time * 1000, inference_time * 1000, + postprocess_time * 1000)) + + def report(self, average=False): + dic = {} + pre_time = self.preprocess_time_s.value() + infer_time = self.inference_time_s.value() + post_time = self.postprocess_time_s.value() + track_time = self.tracking_time_s.value() + + dic['preprocess_time_s'] = round(pre_time / max(1, self.img_num), + 4) if average else pre_time + dic['inference_time_s'] = round(infer_time / max(1, self.img_num), + 4) if average else infer_time + dic['postprocess_time_s'] = round(post_time / max(1, self.img_num), + 4) if average else post_time + dic['img_num'] = self.img_num + total_time = pre_time + infer_time + post_time + if self.with_tracker: + dic['tracking_time_s'] = round(track_time / max(1, self.img_num), + 4) if average else track_time + total_time = total_time + track_time + dic['total_time_s'] = round(total_time, 4) + return dic + + + +def multiclass_nms(bboxs, num_classes, match_threshold=0.6, match_metric='iou'): + final_boxes = [] + for c in range(num_classes): + idxs = bboxs[:, 0] == c + if np.count_nonzero(idxs) == 0: continue + r = nms(bboxs[idxs, 1:], match_threshold, match_metric) + final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1)) + return final_boxes + + +def nms(dets, match_threshold=0.6, match_metric='iou'): + """ Apply NMS to avoid detecting too many overlapping bounding boxes. + Args: + dets: shape [N, 5], [score, x1, y1, x2, y2] + match_metric: 'iou' or 'ios' + match_threshold: overlap thresh for match metric. + """ + if dets.shape[0] == 0: + return dets[[], :] + scores = dets[:, 0] + x1 = dets[:, 1] + y1 = dets[:, 2] + x2 = dets[:, 3] + y2 = dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + ndets = dets.shape[0] + suppressed = np.zeros((ndets), dtype=np.int32) + + for _i in range(ndets): + i = order[_i] + if suppressed[i] == 1: + continue + ix1 = x1[i] + iy1 = y1[i] + ix2 = x2[i] + iy2 = y2[i] + iarea = areas[i] + for _j in range(_i + 1, ndets): + j = order[_j] + if suppressed[j] == 1: + continue + xx1 = max(ix1, x1[j]) + yy1 = max(iy1, y1[j]) + xx2 = min(ix2, x2[j]) + yy2 = min(iy2, y2[j]) + w = max(0.0, xx2 - xx1 + 1) + h = max(0.0, yy2 - yy1 + 1) + inter = w * h + if match_metric == 'iou': + union = iarea + areas[j] - inter + match_value = inter / union + elif match_metric == 'ios': + smaller = min(iarea, areas[j]) + match_value = inter / smaller + else: + raise ValueError() + if match_value >= match_threshold: + suppressed[j] = 1 + keep = np.where(suppressed == 0)[0] + dets = dets[keep, :] + return dets + + +coco_clsid2catid = { + 0: 1, + 1: 2, + 2: 3, + 3: 4, + 4: 5, + 5: 6, + 6: 7, + 7: 8, + 8: 9, + 9: 10, + 10: 11, + 11: 13, + 12: 14, + 13: 15, + 14: 16, + 15: 17, + 16: 18, + 17: 19, + 18: 20, + 19: 21, + 20: 22, + 21: 23, + 22: 24, + 23: 25, + 24: 27, + 25: 28, + 26: 31, + 27: 32, + 28: 33, + 29: 34, + 30: 35, + 31: 36, + 32: 37, + 33: 38, + 34: 39, + 35: 40, + 36: 41, + 37: 42, + 38: 43, + 39: 44, + 40: 46, + 41: 47, + 42: 48, + 43: 49, + 44: 50, + 45: 51, + 46: 52, + 47: 53, + 48: 54, + 49: 55, + 50: 56, + 51: 57, + 52: 58, + 53: 59, + 54: 60, + 55: 61, + 56: 62, + 57: 63, + 58: 64, + 59: 65, + 60: 67, + 61: 70, + 62: 72, + 63: 73, + 64: 74, + 65: 75, + 66: 76, + 67: 77, + 68: 78, + 69: 79, + 70: 80, + 71: 81, + 72: 82, + 73: 84, + 74: 85, + 75: 86, + 76: 87, + 77: 88, + 78: 89, + 79: 90 +} + + +def gaussian_radius(bbox_size, min_overlap): + height, width = bbox_size + + a1 = 1 + b1 = (height + width) + c1 = width * height * (1 - 
min_overlap) / (1 + min_overlap) + sq1 = np.sqrt(b1**2 - 4 * a1 * c1) + radius1 = (b1 + sq1) / (2 * a1) + + a2 = 4 + b2 = 2 * (height + width) + c2 = (1 - min_overlap) * width * height + sq2 = np.sqrt(b2**2 - 4 * a2 * c2) + radius2 = (b2 + sq2) / 2 + + a3 = 4 * min_overlap + b3 = -2 * min_overlap * (height + width) + c3 = (min_overlap - 1) * width * height + sq3 = np.sqrt(b3**2 - 4 * a3 * c3) + radius3 = (b3 + sq3) / 2 + return min(radius1, radius2, radius3) + + +def gaussian2D(shape, sigma_x=1, sigma_y=1): + m, n = [(ss - 1.) / 2. for ss in shape] + y, x = np.ogrid[-m:m + 1, -n:n + 1] + + h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y * + sigma_y))) + h[h < np.finfo(h.dtype).eps * h.max()] = 0 + return h + + +def draw_umich_gaussian(heatmap, center, radius, k=1): + """ + draw_umich_gaussian, refer to https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py#L126 + """ + diameter = 2 * radius + 1 + gaussian = gaussian2D( + (diameter, diameter), sigma_x=diameter / 6, sigma_y=diameter / 6) + + x, y = int(center[0]), int(center[1]) + + height, width = heatmap.shape[0:2] + + left, right = min(x, radius), min(width - x, radius + 1) + top, bottom = min(y, radius), min(height - y, radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[radius - top:radius + bottom, radius - left: + radius + right] + if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: + np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) + return heatmap + + +def iou(box1, box2): + """计算两个框的 IoU(交并比)""" + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + + inter_area = max(0, x2 - x1) * max(0, y2 - y1) + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + + union_area = box1_area + box2_area - inter_area + return inter_area / union_area if union_area != 0 else 0 + +def non_max_suppression(boxes, scores, iou_threshold): + """非极大值抑制""" + if not boxes: + return [] + indices = np.argsort(scores)[::-1] + selected_boxes = [] + + while len(indices) > 0: + current = indices[0] + selected_boxes.append(current) + remaining = indices[1:] + + filtered_indices = [] + for i in remaining: + if iou(boxes[current], boxes[i]) <= iou_threshold: + filtered_indices.append(i) + + indices = np.array(filtered_indices) + + return selected_boxes + + +def is_box_inside(inner, outer): + """判断inner box是否完全在outer box内""" + return (outer[0] <= inner[0] and outer[1] <= inner[1] and + outer[2] >= inner[2] and outer[3] >= inner[3]) + + +def boxes_overlap(box1: List[float], box2: List[float]) -> bool: + """判断两个框是否有交集""" + x1_max = max(box1[0], box2[0]) + y1_max = max(box1[1], box2[1]) + x2_min = min(box1[2], box2[2]) + y2_min = min(box1[3], box2[3]) + return x1_max < x2_min and y1_max < y2_min + + +def merge_boxes(boxes: List[List[float]]) -> List[float]: + x1 = min(box[0] for box in boxes) + y1 = min(box[1] for box in boxes) + x2 = max(box[2] for box in boxes) + y2 = max(box[3] for box in boxes) + return [x1, y1, x2, y2] + + +def merge_text_and_title_boxes(data: List[LayoutBox], merged_labels: Tuple[int]) -> List[LayoutBox]: + text_title_boxes = [(i, box) for i, box in enumerate(data) if box.clsid in merged_labels] + other_boxes = [box.pos for box in data if box.clsid in (2, 4, 5)] + + text_title_boxes.sort(key=lambda x: x[1].pos[1]) # sort by y1 + + merged = [] + skip_indices = set() + i = 0 + while i < 
len(text_title_boxes): + if i in skip_indices: + i += 1 + continue + current_group = [text_title_boxes[i][1].pos] + group_confidences = [text_title_boxes[i][1].confidence] + j = i + 1 + while j < len(text_title_boxes): + candidate_box = text_title_boxes[j][1].pos + tentative_merge = merge_boxes(current_group + [candidate_box]) + has_intruder = any(boxes_overlap(other, tentative_merge) for other in other_boxes) + if has_intruder: + break + else: + current_group.append(candidate_box) + group_confidences.append(text_title_boxes[j][1].confidence) + skip_indices.add(j) + j += 1 + if len(current_group) > 1: + merged_box = LayoutBox(0, merge_boxes(current_group), max(group_confidences)) + merged.append(merged_box) + else: + idx = text_title_boxes[i][0] + merged.append(data[idx]) + i += 1 + + remaining = [data[i] for i in range(len(data)) if i not in skip_indices and data[i].clsid not in merged_labels] + return merged + remaining diff --git a/helper/page_detection/visualize.py b/helper/page_detection/visualize.py new file mode 100644 index 0000000..9e96c29 --- /dev/null +++ b/helper/page_detection/visualize.py @@ -0,0 +1,649 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division + +import os +import cv2 +import math +import numpy as np +import PIL +from PIL import Image, ImageDraw, ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +def imagedraw_textsize_c(draw, text): + if int(PIL.__version__.split('.')[0]) < 10: + tw, th = draw.textsize(text) + else: + left, top, right, bottom = draw.textbbox((0, 0), text) + tw, th = right - left, bottom - top + + return tw, th + + +def visualize_box_mask(im, results, labels, threshold=0.5): + """ + Args: + im (str/np.ndarray): path of image/np.ndarray read by cv2 + results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + MaskRCNN's results include 'masks': np.ndarray: + shape:[N, im_h, im_w] + labels (list): labels:['class1', ..., 'classn'] + threshold (float): Threshold of score. 
+ Returns: + im (PIL.Image.Image): visualized image + """ + if isinstance(im, str): + im = Image.open(im).convert('RGB') + elif isinstance(im, np.ndarray): + im = Image.fromarray(im) + if 'masks' in results and 'boxes' in results and len(results['boxes']) > 0: + im = draw_mask( + im, results['boxes'], results['masks'], labels, threshold=threshold) + if 'boxes' in results and len(results['boxes']) > 0: + im = draw_box(im, results['boxes'], labels, threshold=threshold) + if 'segm' in results: + im = draw_segm( + im, + results['segm'], + results['label'], + results['score'], + labels, + threshold=threshold) + return im + + +def get_color_map_list(num_classes): + """ + Args: + num_classes (int): number of class + Returns: + color_map (list): RGB color list + """ + color_map = num_classes * [0, 0, 0] + for i in range(0, num_classes): + j = 0 + lab = i + while lab: + color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) + color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) + color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + j += 1 + lab >>= 3 + color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)] + return color_map + + +def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5): + """ + Args: + im (PIL.Image.Image): PIL image + np_boxes (np.ndarray): shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + np_masks (np.ndarray): shape:[N, im_h, im_w] + labels (list): labels:['class1', ..., 'classn'] + threshold (float): threshold of mask + Returns: + im (PIL.Image.Image): visualized image + """ + color_list = get_color_map_list(len(labels)) + w_ratio = 0.4 + alpha = 0.7 + im = np.array(im).astype('float32') + clsid2color = {} + expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1) + np_boxes = np_boxes[expect_boxes, :] + np_masks = np_masks[expect_boxes, :, :] + im_h, im_w = im.shape[:2] + np_masks = np_masks[:, :im_h, :im_w] + for i in range(len(np_masks)): + clsid, score = int(np_boxes[i][0]), np_boxes[i][1] + mask = np_masks[i] + if clsid not in clsid2color: + clsid2color[clsid] = color_list[clsid] + color_mask = clsid2color[clsid] + for c in range(3): + color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255 + idx = np.nonzero(mask) + color_mask = np.array(color_mask) + im[idx[0], idx[1], :] *= 1.0 - alpha + im[idx[0], idx[1], :] += alpha * color_mask + return Image.fromarray(im.astype('uint8')) + + +def draw_box(im, np_boxes, labels, threshold=0.5): + """ + Args: + im (PIL.Image.Image): PIL image + np_boxes (np.ndarray): shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + labels (list): labels:['class1', ..., 'classn'] + threshold (float): threshold of box + Returns: + im (PIL.Image.Image): visualized image + """ + draw_thickness = min(im.size) // 320 + draw = ImageDraw.Draw(im) + clsid2color = {} + color_list = get_color_map_list(len(labels)) + expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1) + np_boxes = np_boxes[expect_boxes, :] + + for dt in np_boxes: + clsid, bbox, score = int(dt[0]), dt[2:], dt[1] + if clsid not in clsid2color: + clsid2color[clsid] = color_list[clsid] + color = tuple(clsid2color[clsid]) + + if len(bbox) == 4: + xmin, ymin, xmax, ymax = bbox + print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],' + 'right_bottom:[{:.2f},{:.2f}]'.format( + int(clsid), score, xmin, ymin, xmax, ymax)) + # draw bbox + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=draw_thickness, + fill=color) 
+ elif len(bbox) == 8: + x1, y1, x2, y2, x3, y3, x4, y4 = bbox + draw.line( + [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)], + width=2, + fill=color) + xmin = min(x1, x2, x3, x4) + ymin = min(y1, y2, y3, y4) + + # draw label + text = "{} {:.4f}".format(labels[clsid], score) + tw, th = imagedraw_textsize_c(draw, text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color) + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + return im + + +def draw_segm(im, + np_segms, + np_label, + np_score, + labels, + threshold=0.5, + alpha=0.7): + """ + Draw segmentation on image + """ + mask_color_id = 0 + w_ratio = .4 + color_list = get_color_map_list(len(labels)) + im = np.array(im).astype('float32') + clsid2color = {} + np_segms = np_segms.astype(np.uint8) + for i in range(np_segms.shape[0]): + mask, score, clsid = np_segms[i], np_score[i], np_label[i] + if score < threshold: + continue + + if clsid not in clsid2color: + clsid2color[clsid] = color_list[clsid] + color_mask = clsid2color[clsid] + for c in range(3): + color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255 + idx = np.nonzero(mask) + color_mask = np.array(color_mask) + idx0 = np.minimum(idx[0], im.shape[0] - 1) + idx1 = np.minimum(idx[1], im.shape[1] - 1) + im[idx0, idx1, :] *= 1.0 - alpha + im[idx0, idx1, :] += alpha * color_mask + sum_x = np.sum(mask, axis=0) + x = np.where(sum_x > 0.5)[0] + sum_y = np.sum(mask, axis=1) + y = np.where(sum_y > 0.5)[0] + x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1] + cv2.rectangle(im, (x0, y0), (x1, y1), + tuple(color_mask.astype('int32').tolist()), 1) + bbox_text = '%s %.2f' % (labels[clsid], score) + t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0] + cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3), + tuple(color_mask.astype('int32').tolist()), -1) + cv2.putText( + im, + bbox_text, (x0, y0 - 2), + cv2.FONT_HERSHEY_SIMPLEX, + 0.3, (0, 0, 0), + 1, + lineType=cv2.LINE_AA) + return Image.fromarray(im.astype('uint8')) + + +def get_color(idx): + idx = idx * 3 + color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255) + return color + + +def visualize_pose(imgfile, + results, + visual_thresh=0.6, + save_name='pose.jpg', + save_dir='output', + returnimg=False, + ids=None): + try: + import matplotlib.pyplot as plt + import matplotlib + plt.switch_backend('agg') + except Exception as e: + print('Matplotlib not found, please install matplotlib.' 
+ 'for example: `pip install matplotlib`.') + raise e + skeletons, scores = results['keypoint'] + skeletons = np.array(skeletons) + kpt_nums = 17 + if len(skeletons) > 0: + kpt_nums = skeletons.shape[1] + if kpt_nums == 17: #plot coco keypoint + EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8), + (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14), + (13, 15), (14, 16), (11, 12)] + else: #plot mpii keypoint + EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7), (7, 8), + (8, 9), (10, 11), (11, 12), (13, 14), (14, 15), (8, 12), + (8, 13)] + NUM_EDGES = len(EDGES) + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + cmap = matplotlib.cm.get_cmap('hsv') + plt.figure() + + img = cv2.imread(imgfile) if type(imgfile) == str else imgfile + + color_set = results['colors'] if 'colors' in results else None + + if 'bbox' in results and ids is None: + bboxs = results['bbox'] + for j, rect in enumerate(bboxs): + xmin, ymin, xmax, ymax = rect + color = colors[0] if color_set is None else colors[color_set[j] % + len(colors)] + cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1) + + canvas = img.copy() + for i in range(kpt_nums): + for j in range(len(skeletons)): + if skeletons[j][i, 2] < visual_thresh: + continue + if ids is None: + color = colors[i] if color_set is None else colors[color_set[j] + % + len(colors)] + else: + color = get_color(ids[j]) + + cv2.circle( + canvas, + tuple(skeletons[j][i, 0:2].astype('int32')), + 2, + color, + thickness=-1) + + to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0) + fig = matplotlib.pyplot.gcf() + + stickwidth = 2 + + for i in range(NUM_EDGES): + for j in range(len(skeletons)): + edge = EDGES[i] + if skeletons[j][edge[0], 2] < visual_thresh or skeletons[j][edge[ + 1], 2] < visual_thresh: + continue + + cur_canvas = canvas.copy() + X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]] + Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]] + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), + (int(length / 2), stickwidth), + int(angle), 0, 360, 1) + if ids is None: + color = colors[i] if color_set is None else colors[color_set[j] + % + len(colors)] + else: + color = get_color(ids[j]) + cv2.fillConvexPoly(cur_canvas, polygon, color) + canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) + if returnimg: + return canvas + save_name = os.path.join( + save_dir, os.path.splitext(os.path.basename(imgfile))[0] + '_vis.jpg') + plt.imsave(save_name, canvas[:, :, ::-1]) + print("keypoint visualize image saved to: " + save_name) + plt.close() + + +def visualize_attr(im, results, boxes=None, is_mtmct=False): + if isinstance(im, str): + im = Image.open(im) + im = np.ascontiguousarray(np.copy(im)) + im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) + else: + im = np.ascontiguousarray(np.copy(im)) + + im_h, im_w = im.shape[:2] + text_scale = max(0.5, im.shape[0] / 3000.) + text_thickness = 1 + + line_inter = im.shape[0] / 40. 
+ for i, res in enumerate(results): + if boxes is None: + text_w = 3 + text_h = 1 + elif is_mtmct: + box = boxes[i] # multi camera, bbox shape is x,y, w,h + text_w = int(box[0]) + 3 + text_h = int(box[1]) + else: + box = boxes[i] # single camera, bbox shape is 0, 0, x,y, w,h + text_w = int(box[2]) + 3 + text_h = int(box[3]) + for text in res: + text_h += int(line_inter) + text_loc = (text_w, text_h) + cv2.putText( + im, + text, + text_loc, + cv2.FONT_ITALIC, + text_scale, (0, 255, 255), + thickness=text_thickness) + return im + + +def visualize_action(im, + mot_boxes, + action_visual_collector=None, + action_text="", + video_action_score=None, + video_action_text=""): + im = cv2.imread(im) if isinstance(im, str) else im + im_h, im_w = im.shape[:2] + + text_scale = max(1, im.shape[1] / 400.) + text_thickness = 2 + + if action_visual_collector: + id_action_dict = {} + for collector, action_type in zip(action_visual_collector, action_text): + id_detected = collector.get_visualize_ids() + for pid in id_detected: + id_action_dict[pid] = id_action_dict.get(pid, []) + id_action_dict[pid].append(action_type) + for mot_box in mot_boxes: + # mot_box is a format with [mot_id, class, score, xmin, ymin, w, h] + if mot_box[0] in id_action_dict: + text_position = (int(mot_box[3] + mot_box[5] * 0.75), + int(mot_box[4] - 10)) + display_text = ', '.join(id_action_dict[mot_box[0]]) + cv2.putText(im, display_text, text_position, + cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255), 2) + + if video_action_score: + cv2.putText( + im, + video_action_text + ': %.2f' % video_action_score, + (int(im_w / 2), int(15 * text_scale) + 5), + cv2.FONT_ITALIC, + text_scale, (0, 0, 255), + thickness=text_thickness) + + return im + + +def visualize_vehicleplate(im, results, boxes=None): + if isinstance(im, str): + im = Image.open(im) + im = np.ascontiguousarray(np.copy(im)) + im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) + else: + im = np.ascontiguousarray(np.copy(im)) + + im_h, im_w = im.shape[:2] + text_scale = max(1.0, im.shape[0] / 400.) + text_thickness = 2 + + line_inter = im.shape[0] / 40. 
+ for i, res in enumerate(results): + if boxes is None: + text_w = 3 + text_h = 1 + else: + box = boxes[i] + text = res + if text == "": + continue + text_w = int(box[2]) + text_h = int(box[5] + box[3]) + text_loc = (text_w, text_h) + cv2.putText( + im, + "LP: " + text, + text_loc, + cv2.FONT_ITALIC, + text_scale, (0, 255, 255), + thickness=text_thickness) + return im + + +def draw_press_box_lanes(im, np_boxes, labels, threshold=0.5): + """ + Args: + im (PIL.Image.Image): PIL image + np_boxes (np.ndarray): shape:[N,6], N: number of box, + matix element:[class, score, x_min, y_min, x_max, y_max] + labels (list): labels:['class1', ..., 'classn'] + threshold (float): threshold of box + Returns: + im (PIL.Image.Image): visualized image + """ + + if isinstance(im, str): + im = Image.open(im).convert('RGB') + elif isinstance(im, np.ndarray): + im = Image.fromarray(im) + + draw_thickness = min(im.size) // 320 + draw = ImageDraw.Draw(im) + clsid2color = {} + color_list = get_color_map_list(len(labels)) + + if np_boxes.shape[1] == 7: + np_boxes = np_boxes[:, 1:] + + expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1) + np_boxes = np_boxes[expect_boxes, :] + + for dt in np_boxes: + clsid, bbox, score = int(dt[0]), dt[2:], dt[1] + if clsid not in clsid2color: + clsid2color[clsid] = color_list[clsid] + color = tuple(clsid2color[clsid]) + + if len(bbox) == 4: + xmin, ymin, xmax, ymax = bbox + # draw bbox + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=draw_thickness, + fill=(0, 0, 255)) + elif len(bbox) == 8: + x1, y1, x2, y2, x3, y3, x4, y4 = bbox + draw.line( + [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)], + width=2, + fill=color) + xmin = min(x1, x2, x3, x4) + ymin = min(y1, y2, y3, y4) + + # draw label + text = "{}".format(labels[clsid]) + tw, th = imagedraw_textsize_c(draw, text) + draw.rectangle( + [(xmin + 1, ymax - th), (xmin + tw + 1, ymax)], fill=color) + draw.text((xmin + 1, ymax - th), text, fill=(0, 0, 255)) + return im + + +def visualize_vehiclepress(im, results, threshold=0.5): + results = np.array(results) + labels = ['violation'] + im = draw_press_box_lanes(im, results, labels, threshold=threshold) + return im + + +def visualize_lane(im, lanes): + if isinstance(im, str): + im = Image.open(im).convert('RGB') + elif isinstance(im, np.ndarray): + im = Image.fromarray(im) + + draw_thickness = min(im.size) // 320 + draw = ImageDraw.Draw(im) + + if len(lanes) > 0: + for lane in lanes: + draw.line( + [(lane[0], lane[1]), (lane[2], lane[3])], + width=draw_thickness, + fill=(0, 0, 255)) + + return im + + +def visualize_vehicle_retrograde(im, mot_res, vehicle_retrograde_res): + if isinstance(im, str): + im = Image.open(im).convert('RGB') + elif isinstance(im, np.ndarray): + im = Image.fromarray(im) + + draw_thickness = min(im.size) // 320 + draw = ImageDraw.Draw(im) + + lane = vehicle_retrograde_res['fence_line'] + if lane is not None: + draw.line( + [(lane[0], lane[1]), (lane[2], lane[3])], + width=draw_thickness, + fill=(0, 0, 0)) + + mot_id = vehicle_retrograde_res['output'] + if mot_id is None or len(mot_id) == 0: + return im + + if mot_res is None: + return im + np_boxes = mot_res['boxes'] + + if np_boxes is not None: + for dt in np_boxes: + if dt[0] not in mot_id: + continue + bbox = dt[3:] + if len(bbox) == 4: + xmin, ymin, xmax, ymax = bbox + # draw bbox + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=draw_thickness, + fill=(0, 255, 0)) + + # draw label + text = 
"retrograde" + tw, th = imagedraw_textsize_c(draw, text) + draw.rectangle( + [(xmax + 1, ymin - th), (xmax + tw + 1, ymin)], + fill=(0, 255, 0)) + draw.text((xmax + 1, ymin - th), text, fill=(0, 255, 0)) + + return im + + +COLORS = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + (128, 255, 0), + (255, 128, 0), + (128, 0, 255), + (255, 0, 128), + (0, 128, 255), + (0, 255, 128), + (128, 255, 255), + (255, 128, 255), + (255, 255, 128), + (60, 180, 0), + (180, 60, 0), + (0, 60, 180), + (0, 180, 60), + (60, 0, 180), + (180, 0, 60), + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + (128, 255, 0), + (255, 128, 0), + (128, 0, 255), +] + + +def imshow_lanes(img, lanes, show=False, out_file=None, width=4): + lanes_xys = [] + for _, lane in enumerate(lanes): + xys = [] + for x, y in lane: + if x <= 0 or y <= 0: + continue + x, y = int(x), int(y) + xys.append((x, y)) + lanes_xys.append(xys) + lanes_xys.sort(key=lambda xys: xys[0][0] if len(xys) > 0 else 0) + + for idx, xys in enumerate(lanes_xys): + for i in range(1, len(xys)): + cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width) + + if show: + cv2.imshow('view', img) + cv2.waitKey(0) + + if out_file: + if not os.path.exists(os.path.dirname(out_file)): + os.makedirs(os.path.dirname(out_file)) + cv2.imwrite(out_file, img) \ No newline at end of file diff --git a/magic-pdf.json b/magic-pdf.json new file mode 100644 index 0000000..574153e --- /dev/null +++ b/magic-pdf.json @@ -0,0 +1,62 @@ +{ + "bucket_info": { + "bucket-name-1": [ + "ak", + "sk", + "endpoint" + ], + "bucket-name-2": [ + "ak", + "sk", + "endpoint" + ] + }, + "models-dir": "/root/.cache/modelscope/hub/models/opendatalab/PDF-Extract-Kit-1___0/models", + "layoutreader-model-dir": "/root/.cache/modelscope/hub/models/ppaanngggg/layoutreader", + "device-mode": "cuda", + "layout-config": { + "model": "doclayout_yolo" + }, + "formula-config": { + "mfd_model": "yolo_v8_mfd", + "mfr_model": "unimernet_small", + "enable": false + }, + "table-config": { + "model": "rapid_table", + "sub_model": "slanet_plus", + "enable": false, + "max_time": 400 + }, + "latex-delimiter-config": { + "display": { + "left": "$$", + "right": "$$" + }, + "inline": { + "left": "$", + "right": "$" + } + }, + "llm-aided-config": { + "formula_aided": { + "api_key": "your_api_key", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "model": "qwen2.5-7b-instruct", + "enable": false + }, + "text_aided": { + "api_key": "your_api_key", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "model": "qwen2.5-7b-instruct", + "enable": false + }, + "title_aided": { + "api_key": "your_api_key", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "model": "qwen2.5-32b-instruct", + "enable": false + } + }, + "config_version": "1.2.1" +} diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/infer_cfg.yml b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/infer_cfg.yml new file mode 100644 index 0000000..01a1a7c --- /dev/null +++ b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/infer_cfg.yml @@ -0,0 +1,48 @@ +mode: paddle +draw_threshold: 0.5 +metric: COCO +use_dynamic_shape: false +arch: PicoDet +min_subgraph_size: 3 +Preprocess: +- interp: 2 + keep_ratio: false + target_size: + - 800 + - 608 + type: Resize +- is_scale: true + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 
0.229 + - 0.224 + - 0.225 + type: NormalizeImage +- type: Permute +- stride: 32 + type: PadStride +label_list: +- Text +- Title +- Figure +- Figure caption +- Table +- Table caption +- Header +- Footer +- Reference +- Equation +NMS: + keep_top_k: 100 + name: MultiClassNMS + nms_threshold: 0.5 + nms_top_k: 1000 + score_threshold: 0.3 +fpn_stride: +- 8 +- 16 +- 32 +- 64 diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams new file mode 100644 index 0000000..956d52f Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams.info b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams.info new file mode 100644 index 0000000..960eaba Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams.info differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdmodel b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdmodel new file mode 100644 index 0000000..2ba4d31 Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdmodel differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/infer_cfg.yml b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/infer_cfg.yml new file mode 100644 index 0000000..d80c7c2 --- /dev/null +++ b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/infer_cfg.yml @@ -0,0 +1,43 @@ +mode: paddle +draw_threshold: 0.5 +metric: COCO +use_dynamic_shape: false +arch: PicoDet +min_subgraph_size: 3 +Preprocess: +- interp: 2 + keep_ratio: false + target_size: + - 800 + - 608 + type: Resize +- is_scale: true + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + type: NormalizeImage +- type: Permute +- stride: 32 + type: PadStride +label_list: +- text +- title +- list +- table +- figure +NMS: + keep_top_k: 100 + name: MultiClassNMS + nms_threshold: 0.5 + nms_top_k: 1000 + score_threshold: 0.3 +fpn_stride: +- 8 +- 16 +- 32 +- 64 diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams new file mode 100644 index 0000000..64be068 Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams.info b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams.info new file mode 100644 index 0000000..960eaba Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams.info differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdmodel b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdmodel new file mode 100644 index 0000000..6f6d382 Binary files /dev/null and 
b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdmodel differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/infer_cfg.yml b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/infer_cfg.yml new file mode 100644 index 0000000..bc6bba7 --- /dev/null +++ b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/infer_cfg.yml @@ -0,0 +1,39 @@ +mode: paddle +draw_threshold: 0.5 +metric: COCO +use_dynamic_shape: false +arch: PicoDet +min_subgraph_size: 3 +Preprocess: +- interp: 2 + keep_ratio: false + target_size: + - 800 + - 608 + type: Resize +- is_scale: true + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + type: NormalizeImage +- type: Permute +- stride: 32 + type: PadStride +label_list: +- table +NMS: + keep_top_k: 100 + name: MultiClassNMS + nms_threshold: 0.5 + nms_top_k: 1000 + score_threshold: 0.3 +fpn_stride: +- 8 +- 16 +- 32 +- 64 diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams new file mode 100644 index 0000000..cfd220e Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams.info b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams.info new file mode 100644 index 0000000..960eaba Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams.info differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdmodel b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdmodel new file mode 100644 index 0000000..50135d1 Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdmodel differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/infer_cfg.yml b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/infer_cfg.yml new file mode 100644 index 0000000..d80c7c2 --- /dev/null +++ b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/infer_cfg.yml @@ -0,0 +1,43 @@ +mode: paddle +draw_threshold: 0.5 +metric: COCO +use_dynamic_shape: false +arch: PicoDet +min_subgraph_size: 3 +Preprocess: +- interp: 2 + keep_ratio: false + target_size: + - 800 + - 608 + type: Resize +- is_scale: true + mean: + - 0.485 + - 0.456 + - 0.406 + std: + - 0.229 + - 0.224 + - 0.225 + type: NormalizeImage +- type: Permute +- stride: 32 + type: PadStride +label_list: +- text +- title +- list +- table +- figure +NMS: + keep_top_k: 100 + name: MultiClassNMS + nms_threshold: 0.5 + nms_top_k: 1000 + score_threshold: 0.3 +fpn_stride: +- 8 +- 16 +- 32 +- 64 diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams new file mode 100644 index 0000000..84a3e8b Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams differ diff --git 
a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams.info b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams.info new file mode 100644 index 0000000..960eaba Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams.info differ diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdmodel b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdmodel new file mode 100644 index 0000000..56a692c Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdmodel differ diff --git a/pipeline.py b/pipeline.py new file mode 100644 index 0000000..ccf62e4 --- /dev/null +++ b/pipeline.py @@ -0,0 +1,94 @@ +from dotenv import load_dotenv +import os + +env = os.environ.get('env', 'dev') +load_dotenv(dotenv_path='.env.dev' if env == 'dev' else '.env', override=True) + +import time +import traceback +import cv2 +from helper.image_helper import pdf2image, image_orient_cls, page_detection_visual +from helper.page_detection.main import layout_analysis +from helper.content_recognition.main import rec +from helper.db_helper import insert_pdf2md_table +import tempfile +from loguru import logger +import datetime +import shutil + + +def _pdf2markdown_pipeline(pdf_path, tmp_dir): + start_time = time.time() + # 1. pdf -> images + t1 = time.time() + pdf2image(pdf_path, tmp_dir) + t2 = time.time() + + # 2. 图片方向分类 + t3 = time.time() + orient_cls_results = image_orient_cls(tmp_dir) + t4 = time.time() + for r in orient_cls_results: + clsid = r[0]['class_ids'][0] + filename = r[0]['filename'] + if clsid == 1 or clsid == 3: + img = cv2.imread(filename) + img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) + cv2.imwrite(filename, img) + + filepaths = os.listdir(tmp_dir) + filepaths.sort(key=lambda x: int(x.split('/')[-1].split('.')[0])) + filepaths = [f'{tmp_dir}/{_}' for _ in filepaths] + + # filepaths = filepaths[:75] + + # 3. 版面分析 + t5 = time.time() + layout_detection_results = layout_analysis(filepaths) + t6 = time.time() + + # 3.1 visual + if int(os.environ['VISUAL']): + visual_dir = './visual_images' + for f in os.listdir(visual_dir): + if f.endswith('.jpg'): + os.remove(f'{visual_dir}/{f}') + for i in range(len(layout_detection_results)): + vis_img = page_detection_visual(layout_detection_results[i]) + cv2.imwrite(f'{visual_dir}/{i + 1}.jpg', vis_img) + + # 4. 内容识别 + t7 = time.time() + layout_recognition_results = rec(layout_detection_results, tmp_dir) + t8 = time.time() + + end_time = time.time() + logger.info(f'{pdf_path} analysis completed in {round(end_time - start_time, 3)} seconds, including {round(t2 - t1, 3)} for pdf to image, {round(t4 - t3, 3)} second for image orient classification, {round(t6 - t5, 3)} seconds for page detection, and {round(t8 - t7, 3)} seconds for layout recognition, page number: {len(filepaths)}') + + return layout_recognition_results + + +def pdf2markdown_pipeline(pdf_path: str): + pdf_name = pdf_path.split('/')[-1] + start_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + process_status = 0 + tmp_dir = tempfile.mkdtemp() + try: + results = _pdf2markdown_pipeline(pdf_path, tmp_dir) + except Exception: + logger.error(f'analysis pdf error! 
\n{traceback.format_exc()}') + process_status = 3 + end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, None) + pdf_id = None + else: + process_status = 2 + end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') + pdf_id = insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, results) + finally: + shutil.rmtree(tmp_dir) + return process_status, pdf_id + + +if __name__ == '__main__': + pdf2markdown_pipeline('/mnt/pdf2markdown/demo.pdf') diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..db03014 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,189 @@ +albucore==0.0.24 +albumentations==2.0.6 +annotated-types==0.7.0 +anthropic==0.46.0 +antlr4-python3-runtime==4.9.3 +anyio==4.9.0 +astor==0.8.1 +babel==2.17.0 +bce-python-sdk==0.9.29 +beautifulsoup4==4.13.4 +blinker==1.9.0 +boto3==1.38.8 +botocore==1.38.8 +Brotli==1.1.0 +cachetools==5.5.2 +certifi==2025.4.26 +cffi==1.17.1 +cfgv==3.4.0 +charset-normalizer==3.4.1 +click==8.1.8 +cobble==0.1.4 +coloredlogs==15.0.1 +colorlog==6.9.0 +contourpy==1.3.2 +cryptography==44.0.3 +cssselect2==0.8.0 +cycler==0.12.1 +Cython==3.0.12 +decorator==5.2.1 +dill==0.4.0 +distlib==0.3.9 +distro==1.9.0 +doclayout_yolo==0.0.2b1 +easydict==1.13 +EbookLib==0.18 +et_xmlfile==2.0.0 +faiss-cpu==1.8.0.post1 +fast-langdetect==0.2.5 +fasttext-predict==0.9.2.4 +filelock==3.18.0 +filetype==1.2.0 +fire==0.7.0 +Flask==3.1.0 +flask-babel==4.0.0 +flatbuffers==25.2.10 +fonttools==4.57.0 +fsspec==2025.3.2 +ftfy==6.3.1 +future==1.0.0 +gast==0.3.3 +google-auth==2.39.0 +google-genai==1.13.0 +h11==0.16.0 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.30.2 +humanfriendly==10.0 +identify==2.6.10 +idna==3.10 +imageio==2.37.0 +imgaug==0.4.0 +itsdangerous==2.2.0 +Jinja2==3.1.6 +jiter==0.9.0 +jmespath==1.0.1 +joblib==1.4.2 +kiwisolver==1.4.8 +lazy_loader==0.4 +lmdb==1.6.2 +loguru==0.7.3 +lxml==5.4.0 +magic-pdf==1.3.10 +mammoth==1.9.0 +markdown2==2.5.3 +markdownify==0.13.1 +marker-pdf==1.6.2 +MarkupSafe==3.0.2 +matplotlib==3.10.1 +modelscope==1.25.0 +mpmath==1.3.0 +networkx==3.4.2 +nodeenv==1.9.1 +numpy==1.24.4 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +omegaconf==2.3.0 +onnxruntime==1.21.1 +openai==1.77.0 +opencv-contrib-python==4.11.0.86 +opencv-python==4.6.0.66 +opencv-python-headless==4.11.0.86 +openpyxl==3.1.5 +opt-einsum==3.3.0 +packaging==25.0 +paddleclas==2.5.2 +paddleocr==2.10.0 +paddlepaddle-gpu==2.6.2 +pandas==2.2.3 +pdf2image==1.17.0 +pdfminer.six==20250324 +pdftext==0.6.2 +pillow==10.4.0 +platformdirs==4.3.7 +pre_commit==4.2.0 +prettytable==3.16.0 +protobuf==6.30.2 +psutil==7.0.0 +psycopg2==2.9.10 +py-cpuinfo==9.0.0 +pyasn1==0.6.1 +pyasn1_modules==0.4.2 +pyclipper==1.3.0.post6 +pycparser==2.22 +pycryptodome==3.22.0 +pydantic==2.10.6 +pydantic-settings==2.9.1 +pydantic_core==2.27.2 +pydyf==0.11.0 +PyMuPDF==1.24.14 +pyparsing==3.2.3 +pypdfium2==4.30.0 +pyphen==0.17.2 +python-dateutil==2.9.0.post0 +python-docx==1.1.2 +python-dotenv==1.1.0 +python-pptx==1.0.2 +pytz==2025.2 +PyYAML==6.0.2 +rapid-table==1.0.5 +RapidFuzz==3.13.0 +rapidocr==2.0.7 +rarfile==4.2 
+regex==2024.11.6
+requests==2.32.3
+robust-downloader==0.0.2
+rsa==4.9.1
+s3transfer==0.12.0
+safetensors==0.5.3
+scikit-image==0.25.2
+scikit-learn==1.6.1
+scipy==1.15.2
+seaborn==0.13.2
+shapely==2.1.0
+simsimd==6.2.1
+six==1.17.0
+sniffio==1.3.1
+soupsieve==2.7
+stringzilla==3.12.5
+surya-ocr==0.13.1
+sympy==1.13.1
+termcolor==3.1.0
+thop==0.1.1.post2209072238
+threadpoolctl==3.6.0
+tifffile==2025.3.30
+tinycss2==1.4.0
+tinyhtml5==2.0.0
+tokenizers==0.21.1
+torch==2.6.0
+torchvision==0.21.0
+tqdm==4.67.1
+transformers==4.51.3
+triton==3.2.0
+typing-inspection==0.4.0
+typing_extensions==4.13.2
+tzdata==2025.2
+ujson==5.10.0
+ultralytics==8.3.127
+ultralytics-thop==2.0.14
+urllib3==2.4.0
+virtualenv==20.31.1
+visualdl==2.5.3
+wcwidth==0.2.13
+weasyprint==63.1
+webencodings==0.5.1
+websockets==15.0.1
+Werkzeug==3.1.3
+XlsxWriter==3.2.3
+zopfli==0.2.3.post1
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..17d28ae
--- /dev/null
+++ b/server.py
@@ -0,0 +1,20 @@
+from flask import Flask, request
+import requests
+from pipeline import pdf2markdown_pipeline
+
+
+app = Flask(__name__)
+
+
+# The request carries a JSON body, so the route must accept POST.
+@app.route('/pdf-qa-server/pdf-to-md', methods=['POST'])
+def pdf2markdown():
+    data = request.json
+    pdf_paths = data['pathList']
+    callback_url = data['webhookUrl']
+    for pdf_path in pdf_paths:
+        process_status, pdf_id = pdf2markdown_pipeline(pdf_path)
+        # Report each PDF's processing status back to the caller via the webhook.
+        requests.post(callback_url, json={'pdfId': pdf_id, 'processStatus': process_status})
+    # A Flask view must return a response; acknowledge completion.
+    return 'ok'
diff --git a/visual_images/.gitkeep b/visual_images/.gitkeep
new file mode 100644
index 0000000..e69de29
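For local verification, a minimal client-side sketch that exercises the new `/pdf-qa-server/pdf-to-md` endpoint is shown below; the host, port, and webhook URL are placeholders (assumptions, not part of this change), and the PDF path reuses the example from `pipeline.py`:

```python
# Hypothetical smoke test for the pdf-to-md endpoint.
# Host/port and webhook URL are assumptions; adjust to the real deployment.
import requests

resp = requests.post(
    'http://127.0.0.1:5000/pdf-qa-server/pdf-to-md',
    json={
        'pathList': ['/mnt/pdf2markdown/demo.pdf'],
        'webhookUrl': 'http://127.0.0.1:8080/pdf-callback',
    },
)
print(resp.status_code, resp.text)
```

Note that the endpoint converts the listed PDFs synchronously before responding, so the request may take a while for large files; the per-PDF results arrive at the webhook as `{'pdfId': ..., 'processStatus': ...}`.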