first commit

pull/2/head
zhangzhichao 5 months ago
parent c7a0a4a452
commit 76e53818ce

@ -0,0 +1,7 @@
POSTGRESQL_HOST=
POSTGRESQL_PORT=
POSTGRESQL_USERNAME=
POSTGRESQL_PASSWORD=
POSTGRESQL_DATABASE=
VISUAL=0

@ -0,0 +1,7 @@
POSTGRESQL_HOST=192.168.10.137
POSTGRESQL_PORT=54321
POSTGRESQL_USERNAME=postgres
POSTGRESQL_PASSWORD=123456
POSTGRESQL_DATABASE=pdf-qa
VISUAL=1
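
The values above are the project's PostgreSQL connection settings plus a VISUAL debug flag. As a minimal, hedged sketch of how they might be consumed (it assumes the variables are exported into the process environment, for example via python-dotenv, and that psycopg2 is the driver in use; neither is shown in this commit, though the setup notes below install libpq-dev):
```
import os
import psycopg2  # assumption: psycopg2 is the driver; the setup notes below install libpq-dev

# read the connection settings published in the .env file
conn = psycopg2.connect(
    host=os.environ["POSTGRESQL_HOST"],
    port=int(os.environ["POSTGRESQL_PORT"]),
    user=os.environ["POSTGRESQL_USERNAME"],
    password=os.environ["POSTGRESQL_PASSWORD"],
    dbname=os.environ["POSTGRESQL_DATABASE"],
)
with conn.cursor() as cur:
    cur.execute("SELECT version();")  # simple connectivity check
    print(cur.fetchone())
conn.close()
```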

.gitignore vendored (+5)

@ -0,0 +1,5 @@
venv
*.pdf
.vscode
visual_images/*.jpg
__pycache__

@ -0,0 +1,21 @@
# Environment Setup
```
# 1. Converting PDFs to images requires the following dependency
apt install -y poppler-utils
# 2. Work around a Paddle compatibility issue
export FLAGS_enable_pir_api=0
# 3. Environment configuration required after installing paddlepaddle-gpu
apt install libcudnn8
apt install libcudnn8-dev
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/nvidia/cublas/lib/" >> /etc/profile
# 4. Download paths of the models used by a manual MinerU install:
# model_dir is: /root/.cache/modelscope/hub/models/opendatalab/PDF-Extract-Kit-1___0/models
# layoutreader_model_dir is: /root/.cache/modelscope/hub/models/ppaanngggg/layoutreader
# The configuration file has been configured successfully, the path is: /root/magic-pdf.json
# The project's magic-pdf.json must be symlinked to /root/magic-pdf.json
ln -s `pwd`/magic-pdf.json /root/magic-pdf.json
python download_MinerU_models.py
# 5. Dependencies required for connecting to PostgreSQL from Python
apt install postgresql postgresql-contrib libpq-dev
```
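
Once download_MinerU_models.py (below) has run and the symlink is in place, a quick sanity check of the generated configuration could look like this; it only relies on the 'models-dir' and 'layoutreader-model-dir' keys that the script writes:
```
import json
import os

# verify that the MinerU config exists and that the model directories it points to are present
with open(os.path.expanduser('~/magic-pdf.json'), encoding='utf-8') as f:
    cfg = json.load(f)
for key in ('models-dir', 'layoutreader-model-dir'):
    path = cfg.get(key)
    print(key, path, 'OK' if path and os.path.isdir(path) else 'MISSING')
```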

@ -0,0 +1,67 @@
import json
import shutil
import os
import requests
from modelscope import snapshot_download
def download_json(url):
# download the JSON file
response = requests.get(url)
response.raise_for_status() # check that the request succeeded
return response.json()
def download_and_modify_json(url, local_filename, modifications):
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.2.0':
data = download_json(url)
else:
data = download_json(url)
# apply the modifications
for key, value in modifications.items():
data[key] = value
# save the modified content
with open(local_filename, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
if __name__ == '__main__':
mineru_patterns = [
# "models/Layout/LayoutLMv3/*",
"models/Layout/YOLO/*",
"models/MFD/YOLO/*",
"models/MFR/unimernet_hf_small_2503/*",
"models/OCR/paddleocr_torch/*",
# "models/TabRec/TableMaster/*",
# "models/TabRec/StructEqTable/*",
]
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
model_dir = model_dir + '/models'
print(f'model_dir is: {model_dir}')
print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
# paddleocr_model_dir = model_dir + '/OCR/paddleocr'
# user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
# if os.path.exists(user_paddleocr_dir):
# shutil.rmtree(user_paddleocr_dir)
# shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)
json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json'
config_file_name = 'magic-pdf.json'
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name)
json_mods = {
'models-dir': model_dir,
'layoutreader-model-dir': layoutreader_model_dir,
}
download_and_modify_json(json_url, config_file, json_mods)
print(f'The configuration file has been configured successfully, the path is: {config_file}')

@ -0,0 +1,116 @@
from typing import List
import cv2
from .utils import scanning_document_classify, text_rec, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
from tqdm import tqdm
class LayoutRecognitionResult(object):
def __init__(self, clsid, content, box, table_title=None):
self.clsid = clsid
self.content = content
self.box = box
self.table_title = table_title
def __repr__(self):
return f"[{self.clsid}] {self.content}"
expand_pixel = 10
def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
page_recognition_results = []
for page_idx in tqdm(range(len(page_detection_results)), 'Text recognition'):
results = page_detection_results[page_idx]
if not results.boxes:
page_recognition_results.append([])
continue
img = cv2.imread(results.image_path)
h, w = img.shape[:2]
for layout in results.boxes:
# expand the box slightly to make OCR easier
layout.pos[0] -= expand_pixel
layout.pos[1] -= expand_pixel
layout.pos[2] += expand_pixel
layout.pos[3] += expand_pixel
layout.pos[0] = max(0, layout.pos[0])
layout.pos[1] = max(0, layout.pos[1])
layout.pos[2] = min(w, layout.pos[2])
layout.pos[3] = min(h, layout.pos[3])
outputs = []
is_scanning_document = False
for layout in results.boxes:
x1, y1, x2, y2 = layout.pos
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
layout_img = img[y1: y2, x1: x2]
content = None
if layout.clsid == 0:
# text
content = markdown_rec(layout_img)
elif layout.clsid == 2:
# figure
if scanning_document_classify(layout_img):
# scanned document
is_scanning_document = True
content, layout_img = scanning_document_rec(layout_img)
source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
elif layout.clsid == 4:
# table
if scanning_document_classify(layout_img):
is_scanning_document = True
content, layout_img = scanning_document_rec(layout_img)
source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
else:
content = table_rec(layout_img)
elif layout.clsid == 5:
# table caption
ocr_results = text_rec(layout_img)
content = ''
for o in ocr_results:
content += f'{o}\n'
while content.endswith('\n'):
content = content[:-1]
if not content:
continue
result = LayoutRecognitionResult(layout.clsid, content, layout.pos)
outputs.append(result)
if is_scanning_document and len(outputs) == 1:
# for scanned documents, additionally extract the title
h, w = source_page_no_watermark_img.shape[:2]
if h > w:
title_img = source_page_no_watermark_img[:360, :w, ...]
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
# vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
else:
title_img = source_page_no_watermark_img[:410, :w, ...]
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
# vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
title = text_rec(title_img)
outputs[0].table_title = '\n'.join(title)
else:
# automatically assign each table to its nearest caption
assign_tables_to_titles(outputs)
# the caption entries themselves can be dropped now
outputs = [_ for _ in outputs if _.clsid != 5]
# map clsid 2 (figure) and 4 (table) to the database enum value 1 (table)
for o in outputs:
if o.clsid == 2 or o.clsid == 4:
o.clsid = 1
page_recognition_results.append(outputs)
return page_recognition_results
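
For orientation, a hedged sketch of the input contract rec() relies on: each page entry must expose image_path and a boxes list whose items carry pos ([x1, y1, x2, y2]) and clsid, and tmp_dir must hold the per-page images named <page_no>.jpg. The SimpleNamespace objects below are illustrative stand-ins for the real detection output produced elsewhere in the project:
```
from types import SimpleNamespace

# illustrative page-detection output shaped the way rec() reads it
box = SimpleNamespace(clsid=0, pos=[120, 80, 900, 160])            # clsid 0 = text block
page = SimpleNamespace(image_path='tmp_pages/1.jpg', boxes=[box])

# rec() also re-reads f'{tmp_dir}/{page_idx + 1}.jpg' for scanned pages
results = rec([page], tmp_dir='tmp_pages')
print(results[0])
```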

@ -0,0 +1,5 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from .main import RapidTable, RapidTableInput
from .utils import VisTable

@ -0,0 +1,258 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import argparse
import copy
import importlib
import time
from dataclasses import asdict, dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable
from .table_matcher import TableMatch
from .table_structure import TableStructurer, TableStructureUnitable
from markdownify import markdownify as md
logger = Logger(logger_name=__name__).get_log()
root_dir = Path(__file__).resolve().parent
class ModelType(Enum):
PPSTRUCTURE_EN = "ppstructure_en"
PPSTRUCTURE_ZH = "ppstructure_zh"
SLANETPLUS = "slanet_plus"
UNITABLE = "unitable"
ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
KEY_TO_MODEL_URL = {
ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
ModelType.UNITABLE.value: {
"encoder": f"{ROOT_URL}/unitable/encoder.pth",
"decoder": f"{ROOT_URL}/unitable/decoder.pth",
"vocab": f"{ROOT_URL}/unitable/vocab.json",
},
}
@dataclass
class RapidTableInput:
model_type: Optional[str] = ModelType.SLANETPLUS.value
model_path: Union[str, Path, None, Dict[str, str]] = None
use_cuda: bool = False
device: str = "cpu"
@dataclass
class RapidTableOutput:
pred_html: Optional[str] = None
cell_bboxes: Optional[np.ndarray] = None
logic_points: Optional[np.ndarray] = None
elapse: Optional[float] = None
class RapidTable:
def __init__(self, config: RapidTableInput):
self.model_type = config.model_type
if self.model_type not in KEY_TO_MODEL_URL:
model_list = ",".join(KEY_TO_MODEL_URL)
raise ValueError(
f"{self.model_type} is not supported. The currently supported models are {model_list}."
)
config.model_path = self.get_model_path(config.model_type, config.model_path)
if self.model_type == ModelType.UNITABLE.value:
self.table_structure = TableStructureUnitable(asdict(config))
else:
self.table_structure = TableStructurer(asdict(config))
self.table_matcher = TableMatch()
try:
self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
except ModuleNotFoundError:
self.ocr_engine = None
self.load_img = LoadImage()
def __call__(
self,
img_content: Union[str, np.ndarray, bytes, Path],
ocr_result: List[Union[List[List[float]], str, str]] = None,
) -> RapidTableOutput:
if self.ocr_engine is None and ocr_result is None:
raise ValueError(
"One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
)
img = self.load_img(img_content)
s = time.perf_counter()
h, w = img.shape[:2]
if ocr_result is None:
ocr_result = self.ocr_engine(img)
ocr_result = list(
zip(
ocr_result.boxes,
ocr_result.txts,
ocr_result.scores,
)
)
dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))
# undo the box scaling that the slanet-plus model applies to its output
if self.model_type == ModelType.SLANETPLUS.value:
cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
# filter out placeholder bboxes
mask = ~np.all(cell_bboxes == 0, axis=1)
cell_bboxes = cell_bboxes[mask]
logic_points = self.table_matcher.decode_logic_points(pred_structures)
elapse = time.perf_counter() - s
return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
def get_boxes_recs(
self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
) -> Tuple[np.ndarray, Tuple[str, str]]:
dt_boxes, rec_res, scores = list(zip(*ocr_result))
rec_res = list(zip(rec_res, scores))
r_boxes = []
for box in dt_boxes:
box = np.array(box)
x_min = max(0, box[:, 0].min() - 1)
x_max = min(w, box[:, 0].max() + 1)
y_min = max(0, box[:, 1].min() - 1)
y_max = min(h, box[:, 1].max() + 1)
box = [x_min, y_min, x_max, y_max]
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
return dt_boxes, rec_res
def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
h, w = img.shape[:2]
resized = 488
ratio = min(resized / h, resized / w)
w_ratio = resized / (w * ratio)
h_ratio = resized / (h * ratio)
cell_bboxes[:, 0::2] *= w_ratio
cell_bboxes[:, 1::2] *= h_ratio
return cell_bboxes
@staticmethod
def get_model_path(
model_type: str, model_path: Union[str, Path, None]
) -> Union[str, Dict[str, str]]:
if model_path is not None:
return model_path
model_url = KEY_TO_MODEL_URL.get(model_type, None)
if isinstance(model_url, str):
model_path = DownloadModel.download(model_url)
return model_path
if isinstance(model_url, dict):
model_paths = {}
for k, url in model_url.items():
model_paths[k] = DownloadModel.download(
url, save_model_name=f"{model_type}_{Path(url).name}"
)
return model_paths
raise ValueError(f"Model URL: {type(model_url)} is neither str nor dict.")
def parse_args(arg_list: Optional[List[str]] = None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-v",
"--vis",
action="store_true",
default=False,
help="Whether to visualize the layout results.",
)
parser.add_argument(
"-img", "--img_path", type=str, required=True, help="Path to image for layout."
)
parser.add_argument(
"-m",
"--model_type",
type=str,
default=ModelType.SLANETPLUS.value,
choices=list(KEY_TO_MODEL_URL),
)
args = parser.parse_args(arg_list)
return args
try:
ocr_engine = importlib.import_module("rapidocr").RapidOCR()
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Please install the rapidocr by pip install rapidocr"
) from exc
input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value)
table_engine = RapidTable(input_args)
def table2md_pipeline(img):
rapid_ocr_output = ocr_engine(img)
ocr_result = list(
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
)
table_results = table_engine(img, ocr_result)
html_content = table_results.pred_html
md_content = md(html_content)
return md_content
# def main(arg_list: Optional[List[str]] = None):
# args = parse_args(arg_list)
# try:
# ocr_engine = importlib.import_module("rapidocr").RapidOCR()
# except ModuleNotFoundError as exc:
# raise ModuleNotFoundError(
# "Please install the rapidocr by pip install rapidocr"
# ) from exc
# input_args = RapidTableInput(model_type=args.model_type)
# table_engine = RapidTable(input_args)
# img = cv2.imread(args.img_path)
# rapid_ocr_output = ocr_engine(img)
# ocr_result = list(
# zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
# )
# table_results = table_engine(img, ocr_result)
# print(table_results.pred_html)
# viser = VisTable()
# if args.vis:
# img_path = Path(args.img_path)
# save_dir = img_path.resolve().parent
# save_html_path = save_dir / f"{Path(img_path).stem}.html"
# save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
# viser(img_path, table_results, save_html_path, save_drawed_path)
if __name__ == "__main__":
res = table2md_pipeline(cv2.imread('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/11.jpg'))
print('*' * 50)
print(res)

@ -0,0 +1,4 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from .matcher import TableMatch

@ -0,0 +1,199 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
import numpy as np
from .utils import compute_iou, distance
class TableMatch:
def __init__(self, filter_ocr_result=True, use_master=False):
self.filter_ocr_result = filter_ocr_result
self.use_master = use_master
def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
if self.filter_ocr_result:
dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
matched_index = self.match_result(dt_boxes, cell_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
return pred_html
def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
matched = {}
for i, gt_box in enumerate(dt_boxes):
distances = []
for j, pred_box in enumerate(cell_bboxes):
if len(pred_box) == 8:
pred_box = [
np.min(pred_box[0::2]),
np.min(pred_box[1::2]),
np.max(pred_box[0::2]),
np.max(pred_box[1::2]),
]
distances.append(
(distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
) # compute iou and l1 distance
sorted_distances = distances.copy()
# select det box by iou and l1 distance
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0])
)
# must > min_iou
if sorted_distances[0][1] >= 1 - min_iou:
continue
if distances.index(sorted_distances[0]) not in matched:
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched
def get_pred_html(self, pred_structures, matched_index, ocr_contents):
end_html = []
td_index = 0
for tag in pred_structures:
if "</td>" not in tag:
end_html.append(tag)
continue
if "<td></td>" == tag:
end_html.extend("<td>")
if td_index in matched_index.keys():
b_with = False
if (
"<b>" in ocr_contents[matched_index[td_index][0]]
and len(matched_index[td_index]) > 1
):
b_with = True
end_html.extend("<b>")
for i, td_index_index in enumerate(matched_index[td_index]):
content = ocr_contents[td_index_index][0]
if len(matched_index[td_index]) > 1:
if len(content) == 0:
continue
if content[0] == " ":
content = content[1:]
if "<b>" in content:
content = content[3:]
if "</b>" in content:
content = content[:-4]
if len(content) == 0:
continue
if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
content += " "
end_html.extend(content)
if b_with:
end_html.extend("</b>")
if "<td></td>" == tag:
end_html.append("</td>")
else:
end_html.append(tag)
td_index += 1
# Filter <thead></thead><tbody></tbody> elements
filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
end_html = [v for v in end_html if v not in filter_elements]
return "".join(end_html), end_html
def decode_logic_points(self, pred_structures):
logic_points = []
current_row = 0
current_col = 0
max_rows = 0
max_cols = 0
occupied_cells = {} # track cells that are already occupied
def is_occupied(row, col):
return (row, col) in occupied_cells
def mark_occupied(row, col, rowspan, colspan):
for r in range(row, row + rowspan):
for c in range(col, col + colspan):
occupied_cells[(r, c)] = True
i = 0
while i < len(pred_structures):
token = pred_structures[i]
if token == "<tr>":
current_col = 0 # reset the current column each time a <tr> is encountered
elif token == "</tr>":
current_row += 1 # row finished; advance the row index
elif token.startswith("<td"):
colspan = 1
rowspan = 1
j = i
if token != "<td></td>":
j += 1
# extract the colspan and rowspan attributes
while j < len(pred_structures) and not pred_structures[
j
].startswith(">"):
if "colspan=" in pred_structures[j]:
colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
elif "rowspan=" in pred_structures[j]:
rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
j += 1
# skip the attribute tokens that were just consumed
i = j
# find the next unoccupied column
while is_occupied(current_row, current_col):
current_col += 1
# compute the logical coordinates
r_start = current_row
r_end = current_row + rowspan - 1
col_start = current_col
col_end = current_col + colspan - 1
# record the logical coordinates
logic_points.append([r_start, r_end, col_start, col_end])
# mark the occupied cells
mark_occupied(r_start, col_start, rowspan, colspan)
# advance the current column
current_col += colspan
# update the maximum row and column counts
max_rows = max(max_rows, r_end + 1)
max_cols = max(max_cols, col_end + 1)
i += 1
return logic_points
def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
y1 = cell_bboxes[:, 1::2].min()
new_dt_boxes = []
new_rec_res = []
for box, rec in zip(dt_boxes, rec_res):
if np.max(box[1::2]) < y1:
continue
new_dt_boxes.append(box)
new_rec_res.append(rec)
return new_dt_boxes, new_rec_res
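
To make the logic-point format above concrete, here is a small hand-traced check of decode_logic_points; the import path is an assumption about how this vendored package is laid out, and each entry is [row_start, row_end, col_start, col_end]:
```
from rapid_table.table_matcher import TableMatch  # assumed package path for this vendored module

tokens = ["<thead>", "<tr>", "<td", ' colspan="2"', ">", "</td>", "<td></td>", "</tr>", "</thead>"]
print(TableMatch().decode_logic_points(tokens))
# expected: [[0, 0, 0, 1], [0, 0, 2, 2]] -- the first cell spans columns 0-1, the second sits in column 2
```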

@ -0,0 +1,249 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import copy
import re
def deal_isolate_span(thead_part):
"""
Deal with isolate span cases in this function.
It causes by wrong prediction in structure recognition model.
eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
:param thead_part:
:return:
"""
# 1. find out isolate span tokens.
isolate_pattern = (
r'<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
r'<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
r'<td></td> rowspan="(\d)+"></b></td>|'
r'<td></td> colspan="(\d)+"></b></td>'
)
isolate_iter = re.finditer(isolate_pattern, thead_part)
isolate_list = [i.group() for i in isolate_iter]
# 2. find out span number, by step 1 results.
span_pattern = (
r' rowspan="(\d)+" colspan="(\d)+"|'
r' colspan="(\d)+" rowspan="(\d)+"|'
r' rowspan="(\d)+"|'
r' colspan="(\d)+"'
)
corrected_list = []
for isolate_item in isolate_list:
span_part = re.search(span_pattern, isolate_item)
spanStr_in_isolateItem = span_part.group()
# 3. merge the span number into the span token format string.
if spanStr_in_isolateItem is not None:
corrected_item = f"<td{spanStr_in_isolateItem}></td>"
corrected_list.append(corrected_item)
else:
corrected_list.append(None)
# 4. replace original isolated token.
for corrected_item, isolate_item in zip(corrected_list, isolate_list):
if corrected_item is not None:
thead_part = thead_part.replace(isolate_item, corrected_item)
else:
pass
return thead_part
def deal_duplicate_bb(thead_part):
"""
Deal with duplicate <b> or </b> tags after replacement.
Keep one <b></b> in a <td></td> token.
:param thead_part:
:return:
"""
# 1. find out <td></td> in <thead></thead>.
td_pattern = (
r'<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
r'<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
r'<td rowspan="(\d)+">(.+?)</td>|'
r'<td colspan="(\d)+">(.+?)</td>|'
"<td>(.*?)</td>"
)
td_iter = re.finditer(td_pattern, thead_part)
td_list = [t.group() for t in td_iter]
# 2. are there multiple <b></b> in <td></td> or not?
new_td_list = []
for td_item in td_list:
if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
# multiple <b></b> in <td></td> case.
# 1. remove all <b></b>
td_item = td_item.replace("<b>", "").replace("</b>", "")
# 2. replace <tb> -> <tb><b>, </tb> -> </b></tb>.
td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
new_td_list.append(td_item)
else:
new_td_list.append(td_item)
# 3. replace original thead part.
for td_item, new_td_item in zip(td_list, new_td_list):
thead_part = thead_part.replace(td_item, new_td_item)
return thead_part
def deal_bb(result_token):
"""
In our opinion, <b></b> always occurs in <thead></thead> text's context.
This function will find out all tokens in <thead></thead> and insert <b></b> manually.
:param result_token:
:return:
"""
# find out <thead></thead> parts.
thead_pattern = "<thead>(.*?)</thead>"
if re.search(thead_pattern, result_token) is None:
return result_token
thead_part = re.search(thead_pattern, result_token).group()
origin_thead_part = copy.deepcopy(thead_part)
# check whether "rowspan" or "colspan" occurs in the <thead></thead> part.
span_pattern = r'<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan="(\d)+">|<td colspan="(\d)+">'
span_iter = re.finditer(span_pattern, thead_part)
span_list = [s.group() for s in span_iter]
has_span_in_head = True if len(span_list) > 0 else False
if not has_span_in_head:
# <thead></thead> not include "rowspan" or "colspan" branch 1.
# 1. replace <td> to <td><b>, and </td> to </b></td>
# 2. it is possible to predict text include <b> or </b> by Text-line recognition,
# so we replace <b><b> to <b>, and </b></b> to </b>
thead_part = (
thead_part.replace("<td>", "<td><b>")
.replace("</td>", "</b></td>")
.replace("<b><b>", "<b>")
.replace("</b></b>", "</b>")
)
else:
# <thead></thead> include "rowspan" or "colspan" branch 2.
# Firstly, we deal rowspan or colspan cases.
# 1. replace > to ><b>
# 2. replace </td> to </b></td>
# 3. it is possible to predict text include <b> or </b> by Text-line recognition,
# so we replace <b><b> to <b>, and </b><b> to </b>
# Secondly, deal ordinary cases like branch 1
# replace ">" to "<b>"
replaced_span_list = []
for sp in span_list:
replaced_span_list.append(sp.replace(">", "><b>"))
for sp, rsp in zip(span_list, replaced_span_list):
thead_part = thead_part.replace(sp, rsp)
# replace "</td>" to "</b></td>"
thead_part = thead_part.replace("</td>", "</b></td>")
# remove duplicated <b> by re.sub
mb_pattern = "(<b>)+"
single_b_string = "<b>"
thead_part = re.sub(mb_pattern, single_b_string, thead_part)
mgb_pattern = "(</b>)+"
single_gb_string = "</b>"
thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
# ordinary cases like branch 1
thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
# convert <td><b></b></td> back to <td></td>; an empty cell has no <b></b>.
# but a whitespace cell (<td> </td>) should become <td><b> </b></td>
thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
# deal with duplicated <b></b>
thead_part = deal_duplicate_bb(thead_part)
# deal with isolated span tokens, which are caused by wrong predictions from the structure model.
# eg.PMC5994107_011_00.png
thead_part = deal_isolate_span(thead_part)
# replace original result with new thead part.
result_token = result_token.replace(origin_thead_part, thead_part)
return result_token
def deal_eb_token(master_token):
"""
post process with <eb></eb>, <eb1></eb1>, ...
emptyBboxTokenDict = {
"[]": '<eb></eb>',
"[' ']": '<eb1></eb1>',
"['<b>', ' ', '</b>']": '<eb2></eb2>',
"['\\u2028', '\\u2028']": '<eb3></eb3>',
"['<sup>', ' ', '</sup>']": '<eb4></eb4>',
"['<b>', '</b>']": '<eb5></eb5>',
"['<i>', ' ', '</i>']": '<eb6></eb6>',
"['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
"['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
"['<i>', '</i>']": '<eb9></eb9>',
"['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
}
:param master_token:
:return:
"""
master_token = master_token.replace("<eb></eb>", "<td></td>")
master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
master_token = master_token.replace(
"<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
)
return master_token
def distance(box_1, box_2):
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4 - x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
def compute_iou(rec1, rec2):
"""
computing IoU
:param rec1: (y0, x0, y1, x1), which reflects
(top, left, bottom, right)
:param rec2: (y0, x0, y1, x1)
:return: scala value of IoU
"""
# computing area of each rectangles
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
# computing the sum_area
sum_area = S_rec1 + S_rec2
# find the each edge of intersect rectangle
left_line = max(rec1[1], rec2[1])
right_line = min(rec1[3], rec2[3])
top_line = max(rec1[0], rec2[0])
bottom_line = min(rec1[2], rec2[2])
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0.0
intersect = (right_line - left_line) * (bottom_line - top_line)
return (intersect / (sum_area - intersect)) * 1.0
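
A small hand-worked check of the two matching metrics above; both boxes use the same [x1, y1, x2, y2] layout that match_result passes in, and the import path is an assumption about the vendored layout:
```
from rapid_table.table_matcher.utils import compute_iou, distance  # assumed package path

box_a = [0, 0, 10, 10]
box_b = [5, 5, 15, 15]
print(distance(box_a, box_b))     # 20 + min(10, 10) = 30
print(compute_iou(box_a, box_b))  # 25 / (100 + 100 - 25) = 0.142857...
```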

@ -0,0 +1,15 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .table_structure import TableStructurer
from .table_structure_unitable import TableStructureUnitable

@ -0,0 +1,58 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any, Dict
import numpy as np
from .utils import OrtInferSession, TableLabelDecode, TablePreprocess
class TableStructurer:
def __init__(self, config: Dict[str, Any]):
self.preprocess_op = TablePreprocess()
self.session = OrtInferSession(config)
self.character = self.session.get_metadata()
self.postprocess_op = TableLabelDecode(self.character)
def __call__(self, img):
starttime = time.time()
data = {"image": img}
data = self.preprocess_op(data)
img = data[0]
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
outputs = self.session([img])
preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}
shape_list = np.expand_dims(data[-1], axis=0)
post_result = self.postprocess_op(preds, [shape_list])
bbox_list = post_result["bbox_batch_list"][0]
structure_str_list = post_result["structure_batch_list"][0]
structure_str_list = structure_str_list[0]
structure_str_list = (
["<html>", "<body>", "<table>"]
+ structure_str_list
+ ["</table>", "</body>", "</html>"]
)
elapse = time.time() - starttime
return structure_str_list, bbox_list, elapse

@ -0,0 +1,229 @@
import re
import time
import cv2
import numpy as np
import torch
from PIL import Image
from tokenizers import Tokenizer
from torchvision import transforms
from .unitable_modules import Encoder, GPTFastDecoder
IMG_SIZE = 448
EOS_TOKEN = "<eos>"
BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)]
HTML_BBOX_HTML_TOKENS = [
"<td></td>",
"<td>[",
"]</td>",
"<td",
">[",
"></td>",
"<tr>",
"</tr>",
"<tbody>",
"</tbody>",
"<thead>",
"</thead>",
' rowspan="2"',
' rowspan="3"',
' rowspan="4"',
' rowspan="5"',
' rowspan="6"',
' rowspan="7"',
' rowspan="8"',
' rowspan="9"',
' rowspan="10"',
' rowspan="11"',
' rowspan="12"',
' rowspan="13"',
' rowspan="14"',
' rowspan="15"',
' rowspan="16"',
' rowspan="17"',
' rowspan="18"',
' rowspan="19"',
' colspan="2"',
' colspan="3"',
' colspan="4"',
' colspan="5"',
' colspan="6"',
' colspan="7"',
' colspan="8"',
' colspan="9"',
' colspan="10"',
' colspan="11"',
' colspan="12"',
' colspan="13"',
' colspan="14"',
' colspan="15"',
' colspan="16"',
' colspan="17"',
' colspan="18"',
' colspan="19"',
' colspan="25"',
]
VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS
TASK_TOKENS = [
"[table]",
"[html]",
"[cell]",
"[bbox]",
"[cell+bbox]",
"[html+bbox]",
]
class TableStructureUnitable:
def __init__(self, config):
# encoder_path: str, decoder_path: str, vocab_path: str, device: str
vocab_path = config["model_path"]["vocab"]
encoder_path = config["model_path"]["encoder"]
decoder_path = config["model_path"]["decoder"]
device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu"
self.vocab = Tokenizer.from_file(vocab_path)
self.token_white_list = [
self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS
]
self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS)
self.bbox_close_html_token = self.vocab.token_to_id("]</td>")
self.prefix_token_id = self.vocab.token_to_id("[html+bbox]")
self.eos_id = self.vocab.token_to_id(EOS_TOKEN)
self.max_seq_len = 1024
self.device = device
self.img_size = IMG_SIZE
# init encoder
encoder_state_dict = torch.load(encoder_path, map_location=device)
self.encoder = Encoder()
self.encoder.load_state_dict(encoder_state_dict)
self.encoder.eval().to(device)
# init decoder
decoder_state_dict = torch.load(decoder_path, map_location=device)
self.decoder = GPTFastDecoder()
self.decoder.load_state_dict(decoder_state_dict)
self.decoder.eval().to(device)
# define img transform
self.transform = transforms.Compose(
[
transforms.Resize((448, 448)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.86597056, 0.88463002, 0.87491087],
std=[0.20686628, 0.18201602, 0.18485524],
),
]
)
@torch.inference_mode()
def __call__(self, image: np.ndarray):
start_time = time.time()
ori_h, ori_w = image.shape[:2]
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = Image.fromarray(image)
image = self.transform(image).unsqueeze(0).to(self.device)
self.decoder.setup_caches(
max_batch_size=1,
max_seq_length=self.max_seq_len,
dtype=image.dtype,
device=self.device,
)
context = (
torch.tensor([self.prefix_token_id], dtype=torch.int32)
.repeat(1, 1)
.to(self.device)
)
eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(self.device)
memory = self.encoder(image)
context = self.loop_decode(context, eos_id_tensor, memory)
bboxes, html_tokens = self.decode_tokens(context)
bboxes = bboxes.astype(np.float32)
# rescale boxes
scale_h = ori_h / self.img_size
scale_w = ori_w / self.img_size
bboxes[:, 0::2] *= scale_w # scale x coordinates
bboxes[:, 1::2] *= scale_h # scale y coordinates
bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1)
bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1)
structure_str_list = (
["<html>", "<body>", "<table>"]
+ html_tokens
+ ["</table>", "</body>", "</html>"]
)
return structure_str_list, bboxes, time.time() - start_time
def decode_tokens(self, context):
pred_html = context[0]
pred_html = pred_html.detach().cpu().numpy()
pred_html = self.vocab.decode(pred_html, skip_special_tokens=False)
seq = pred_html.split("<eos>")[0]
token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
for i in token_black_list:
seq = seq.replace(i, "")
tr_pattern = re.compile(r"<tr>(.*?)</tr>", re.DOTALL)
td_pattern = re.compile(r"<td(.*?)>(.*?)</td>", re.DOTALL)
bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")
decoded_list = []
bbox_coords = []
# find all <tr> tags
for tr_match in tr_pattern.finditer(pred_html):
tr_content = tr_match.group(1)
decoded_list.append("<tr>")
# find all <td> tags
for td_match in td_pattern.finditer(tr_content):
td_attrs = td_match.group(1).strip()
td_content = td_match.group(2).strip()
if td_attrs:
decoded_list.append("<td")
# rowspan and colspan may both be present; add each attribute
attrs_list = td_attrs.split()
for attr in attrs_list:
decoded_list.append(" " + attr)
decoded_list.append(">")
decoded_list.append("</td>")
else:
decoded_list.append("<td></td>")
# look up the bbox coordinates
bbox_match = bbox_pattern.search(td_content)
if bbox_match:
xmin, ymin, xmax, ymax = map(int, bbox_match.groups())
# convert to corner points, clockwise from the top-left to the bottom-left
coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
bbox_coords.append(coords)
else:
# insert a placeholder bbox so the downstream pipeline stays uniform
bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
decoded_list.append("</tr>")
bbox_coords_array = np.array(bbox_coords)
return bbox_coords_array, decoded_list
def loop_decode(self, context, eos_id_tensor, memory):
box_token_count = 0
for _ in range(self.max_seq_len):
eos_flag = (context == eos_id_tensor).any(dim=1)
if torch.all(eos_flag):
break
next_tokens = self.decoder(memory, context)
if next_tokens[0] in self.bbox_token_ids:
box_token_count += 1
if box_token_count > 4:
next_tokens = torch.tensor(
[self.bbox_close_html_token], dtype=torch.int32
)
box_token_count = 0
context = torch.cat([context, next_tokens], dim=1)
return context

@ -0,0 +1,911 @@
from dataclasses import dataclass
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.transformer import _get_activation_fn
# token ids allowed during decoding: 1 plus the contiguous range 12-509
TOKEN_WHITE_LIST = [1] + list(range(12, 510))
class ImgLinearBackbone(nn.Module):
def __init__(
self,
d_model: int,
patch_size: int,
in_chan: int = 3,
) -> None:
super().__init__()
self.conv_proj = nn.Conv2d(
in_chan,
out_channels=d_model,
kernel_size=patch_size,
stride=patch_size,
)
self.d_model = d_model
def forward(self, x: Tensor) -> Tensor:
x = self.conv_proj(x)
x = x.flatten(start_dim=-2).transpose(1, 2)
return x
class Encoder(nn.Module):
def __init__(self) -> None:
super().__init__()
self.patch_size = 16
self.d_model = 768
self.dropout = 0
self.activation = "gelu"
self.norm_first = True
self.ff_ratio = 4
self.nhead = 12
self.max_seq_len = 1024
self.n_encoder_layer = 12
encoder_layer = nn.TransformerEncoderLayer(
self.d_model,
nhead=self.nhead,
dim_feedforward=self.ff_ratio * self.d_model,
dropout=self.dropout,
activation=self.activation,
batch_first=True,
norm_first=self.norm_first,
)
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm = norm_layer(self.d_model)
self.backbone = ImgLinearBackbone(
d_model=self.d_model, patch_size=self.patch_size
)
self.pos_embed = PositionEmbedding(
max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
)
self.encoder = nn.TransformerEncoder(
encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
)
def forward(self, x: Tensor) -> Tensor:
src_feature = self.backbone(x)
src_feature = self.pos_embed(src_feature)
memory = self.encoder(src_feature)
memory = self.norm(memory)
return memory
class PositionEmbedding(nn.Module):
def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
super().__init__()
self.embedding = nn.Embedding(max_seq_len, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
# assume x is batch first
if input_pos is None:
_pos = torch.arange(x.shape[1], device=x.device)
else:
_pos = input_pos
out = self.embedding(_pos)
return self.dropout(out + x)
class TokenEmbedding(nn.Module):
def __init__(
self,
vocab_size: int,
d_model: int,
padding_idx: int,
) -> None:
super().__init__()
assert vocab_size > 0
self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
def forward(self, x: Tensor) -> Tensor:
return self.embedding(x)
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
@dataclass
class ModelArgs:
n_layer: int = 4
n_head: int = 12
dim: int = 768
intermediate_size: int = None
head_dim: int = 64
activation: str = "gelu"
norm_first: bool = True
def __post_init__(self):
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head
class KVCache(nn.Module):
def __init__(
self,
max_batch_size,
max_seq_length,
n_heads,
head_dim,
dtype=torch.bfloat16,
device="cpu",
):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer(
"k_cache",
torch.zeros(cache_shape, dtype=dtype, device=device),
persistent=False,
)
self.register_buffer(
"v_cache",
torch.zeros(cache_shape, dtype=dtype, device=device),
persistent=False,
)
def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
# assert input_pos.shape[0] == k_val.shape[2]
bs = k_val.shape[0]
k_out = self.k_cache
v_out = self.v_cache
k_out[:bs, :, input_pos] = k_val
v_out[:bs, :, input_pos] = v_val
return k_out[:bs], v_out[:bs]
class GPTFastDecoder(nn.Module):
def __init__(self) -> None:
super().__init__()
self.vocab_size = 960
self.padding_idx = 2
self.prefix_token_id = 11
self.eos_id = 1
self.max_seq_len = 1024
self.dropout = 0
self.d_model = 768
self.nhead = 12
self.activation = "gelu"
self.norm_first = True
self.n_decoder_layer = 4
config = ModelArgs(
n_layer=self.n_decoder_layer,
n_head=self.nhead,
dim=self.d_model,
intermediate_size=self.d_model * 4,
activation=self.activation,
norm_first=self.norm_first,
)
self.config = config
self.layers = nn.ModuleList(
TransformerBlock(config) for _ in range(config.n_layer)
)
self.token_embed = TokenEmbedding(
vocab_size=self.vocab_size,
d_model=self.d_model,
padding_idx=self.padding_idx,
)
self.pos_embed = PositionEmbedding(
max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
)
self.generator = nn.Linear(self.d_model, self.vocab_size)
self.token_white_list = TOKEN_WHITE_LIST
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1
def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
for b in self.layers:
b.multihead_attn.k_cache = None
b.multihead_attn.v_cache = None
if (
self.max_seq_length >= max_seq_length
and self.max_batch_size >= max_batch_size
):
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
for b in self.layers:
b.self_attn.kv_cache = KVCache(
max_batch_size,
max_seq_length,
self.config.n_head,
head_dim,
dtype,
device,
)
b.multihead_attn.k_cache = None
b.multihead_attn.v_cache = None
self.causal_mask = torch.tril(
torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
).to(device)
def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
tgt = tgt[:, -1:]
tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
# tgt = self.decoder(tgt_feature, memory, input_pos)
with torch.backends.cuda.sdp_kernel(
enable_flash=False, enable_mem_efficient=False, enable_math=True
):
logits = tgt_feature
tgt_mask = self.causal_mask[None, None, input_pos]
for i, layer in enumerate(self.layers):
logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
# return output
logits = self.generator(logits)[:, -1, :]
total = set([i for i in range(logits.shape[-1])])
black_list = list(total.difference(set(self.token_white_list)))
logits[..., black_list] = -1e9
probs = F.softmax(logits, dim=-1)
_, next_tokens = probs.topk(1)
return next_tokens
class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.self_attn = Attention(config)
self.multihead_attn = CrossAttention(config)
layer_norm_eps = 1e-5
d_model = config.dim
dim_feedforward = config.intermediate_size
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm_first = config.norm_first
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.activation = _get_activation_fn(config.activation)
def forward(
self,
tgt: Tensor,
memory: Tensor,
tgt_mask: Tensor,
input_pos: Tensor,
) -> Tensor:
if self.norm_first:
x = tgt
x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos)
x = x + self.multihead_attn(self.norm2(x), memory)
x = x + self._ff_block(self.norm3(x))
else:
x = tgt
x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos))
x = self.norm2(x + self.multihead_attn(x, memory))
x = self.norm3(x + self._ff_block(x))
return x
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.activation(self.linear1(x)))
return x
class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.wqkv = nn.Linear(config.dim, 3 * config.dim)
self.wo = nn.Linear(config.dim, config.dim)
self.kv_cache: Optional[KVCache] = None
self.n_head = config.n_head
self.head_dim = config.head_dim
self.dim = config.dim
def forward(
self,
x: Tensor,
mask: Tensor,
input_pos: Optional[Tensor] = None,
) -> Tensor:
bsz, seqlen, _ = x.shape
kv_size = self.n_head * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
if self.kv_cache is not None:
k, v = self.kv_cache.update(input_pos, k, v)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
y = self.wo(y)
return y
class CrossAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0
self.query = nn.Linear(config.dim, config.dim)
self.key = nn.Linear(config.dim, config.dim)
self.value = nn.Linear(config.dim, config.dim)
self.out = nn.Linear(config.dim, config.dim)
self.k_cache = None
self.v_cache = None
self.n_head = config.n_head
self.head_dim = config.head_dim
def get_kv(self, xa: torch.Tensor):
if self.k_cache is not None and self.v_cache is not None:
return self.k_cache, self.v_cache
k = self.key(xa)
v = self.value(xa)
# Reshape for correct format
batch_size, source_seq_len, _ = k.shape
k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)
if self.k_cache is None:
self.k_cache = k
if self.v_cache is None:
self.v_cache = v
return k, v
def forward(
self,
x: Tensor,
xa: Tensor,
):
q = self.query(x)
batch_size, target_seq_len, _ = q.shape
q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
k, v = self.get_kv(xa)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
wv = F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
is_causal=False,
)
wv = wv.transpose(1, 2).reshape(
batch_size,
target_seq_len,
self.n_head * self.head_dim,
)
return self.out(wv)
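
As a quick shape sanity check of the encoder defined above (assuming this module is importable and torch is installed): a 448x448 input is cut into 16x16 patches, giving the (1, 784, 768) memory tensor that TableStructureUnitable feeds to the decoder.
```
import torch
from rapid_table.table_structure.unitable_modules import Encoder  # assumed package path

enc = Encoder().eval()
x = torch.randn(1, 3, 448, 448)        # matches the 448x448 input used by TableStructureUnitable
with torch.inference_mode():
    memory = enc(x)
print(memory.shape)                     # torch.Size([1, 784, 768]) -> (448/16)**2 patches, d_model=768
```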

@ -0,0 +1,544 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os
import platform
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
import cv2
import numpy as np
from onnxruntime import (
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
get_available_providers,
get_device,
)
from rapid_table.utils import Logger
class EP(Enum):
CPU_EP = "CPUExecutionProvider"
CUDA_EP = "CUDAExecutionProvider"
DIRECTML_EP = "DmlExecutionProvider"
class OrtInferSession:
def __init__(self, config: Dict[str, Any]):
self.logger = Logger(logger_name=__name__).get_log()
model_path = config.get("model_path", None)
self._verify_model(model_path)
self.cfg_use_cuda = config.get("use_cuda", None)
self.cfg_use_dml = config.get("use_dml", None)
self.had_providers: List[str] = get_available_providers()
EP_list = self._get_ep_list()
sess_opt = self._init_sess_opts(config)
self.session = InferenceSession(
model_path,
sess_options=sess_opt,
providers=EP_list,
)
self._verify_providers()
@staticmethod
def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
sess_opt = SessionOptions()
sess_opt.log_severity_level = 4
sess_opt.enable_cpu_mem_arena = False
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
cpu_nums = os.cpu_count()
intra_op_num_threads = config.get("intra_op_num_threads", -1)
if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
sess_opt.intra_op_num_threads = intra_op_num_threads
inter_op_num_threads = config.get("inter_op_num_threads", -1)
if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
sess_opt.inter_op_num_threads = inter_op_num_threads
return sess_opt
def get_metadata(self, key: str = "character") -> list:
meta_dict = self.session.get_modelmeta().custom_metadata_map
content_list = meta_dict[key].splitlines()
return content_list
def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
cpu_provider_opts = {
"arena_extend_strategy": "kSameAsRequested",
}
EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
cuda_provider_opts = {
"device_id": 0,
"arena_extend_strategy": "kNextPowerOfTwo",
"cudnn_conv_algo_search": "EXHAUSTIVE",
"do_copy_in_default_stream": True,
}
self.use_cuda = self._check_cuda()
if self.use_cuda:
EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
self.use_directml = self._check_dml()
if self.use_directml:
self.logger.info(
"Windows 10 or above detected, try to use DirectML as primary provider"
)
directml_options = (
cuda_provider_opts if self.use_cuda else cpu_provider_opts
)
EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
return EP_list
def _check_cuda(self) -> bool:
if not self.cfg_use_cuda:
return False
cur_device = get_device()
if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
return True
self.logger.warning(
"%s is not in available providers (%s). Use %s inference by default.",
EP.CUDA_EP.value,
self.had_providers,
self.had_providers[0],
)
self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
self.logger.info(
"(For reference only) If you want to use GPU acceleration, you must do:"
)
self.logger.info(
"First, uninstall all onnxruntime packages in the current environment."
)
self.logger.info(
"Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
)
self.logger.info(
"\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
)
self.logger.info(
"\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
)
self.logger.info(
"Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
EP.CUDA_EP.value,
)
return False
def _check_dml(self) -> bool:
if not self.cfg_use_dml:
return False
cur_os = platform.system()
if cur_os != "Windows":
self.logger.warning(
"DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.",
cur_os,
self.had_providers[0],
)
return False
cur_window_version = int(platform.release().split(".")[0])
if cur_window_version < 10:
self.logger.warning(
"DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.",
cur_window_version,
self.had_providers[0],
)
return False
if EP.DIRECTML_EP.value in self.had_providers:
return True
self.logger.warning(
"%s is not in available providers (%s). Use %s inference by default.",
EP.DIRECTML_EP.value,
self.had_providers,
self.had_providers[0],
)
self.logger.info("If you want to use DirectML acceleration, you must do:")
self.logger.info(
"First, uninstall all onnxruntime packages in the current environment."
)
self.logger.info(
"Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
)
self.logger.info(
"Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
EP.DIRECTML_EP.value,
)
return False
def _verify_providers(self):
session_providers = self.session.get_providers()
first_provider = session_providers[0]
if self.use_cuda and first_provider != EP.CUDA_EP.value:
self.logger.warning(
"%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
EP.CUDA_EP.value,
first_provider,
)
if self.use_directml and first_provider != EP.DIRECTML_EP.value:
self.logger.warning(
"%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
EP.DIRECTML_EP.value,
first_provider,
)
def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content))
try:
return self.session.run(None, input_dict)
except Exception as e:
error_info = traceback.format_exc()
raise ONNXRuntimeError(error_info) from e
def get_input_names(self) -> List[str]:
return [v.name for v in self.session.get_inputs()]
def get_output_names(self) -> List[str]:
return [v.name for v in self.session.get_outputs()]
def get_character_list(self, key: str = "character") -> List[str]:
meta_dict = self.session.get_modelmeta().custom_metadata_map
return meta_dict[key].splitlines()
def have_key(self, key: str = "character") -> bool:
meta_dict = self.session.get_modelmeta().custom_metadata_map
if key in meta_dict.keys():
return True
return False
@staticmethod
def _verify_model(model_path: Union[str, Path, None]):
if model_path is None:
raise ValueError("model_path is None!")
model_path = Path(model_path)
if not model_path.exists():
raise FileNotFoundError(f"{model_path} does not exist.")
if not model_path.is_file():
raise FileExistsError(f"{model_path} is not a file.")
class ONNXRuntimeError(Exception):
pass
class TableLabelDecode:
def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
if merge_no_span_structure:
if "<td></td>" not in dict_character:
dict_character.append("<td></td>")
if "<td>" in dict_character:
dict_character.remove("<td>")
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
self.td_token = ["<td>", "<td", "<td></td>"]
def __call__(self, preds, batch=None):
structure_probs = preds["structure_probs"]
bbox_preds = preds["loc_preds"]
shape_list = batch[-1]
result = self.decode(structure_probs, bbox_preds, shape_list)
if len(batch) == 1: # only contains shape
return result
label_decode_result = self.decode_label(batch)
return result, label_decode_result
def decode(self, structure_probs, bbox_preds, shape_list):
"""convert text-label into text-index."""
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
score_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
text = self.character[char_idx]
if text in self.td_token:
bbox = bbox_preds[batch_idx, idx]
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_list.append(text)
score_list.append(structure_probs[batch_idx, idx])
structure_batch_list.append([structure_list, np.mean(score_list)])
bbox_batch_list.append(np.array(bbox_list))
result = {
"bbox_batch_list": bbox_batch_list,
"structure_batch_list": structure_batch_list,
}
return result
def decode_label(self, batch):
"""convert text-label into text-index."""
structure_idx = batch[1]
gt_bbox_list = batch[2]
shape_list = batch[-1]
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
structure_list.append(self.character[char_idx])
bbox = gt_bbox_list[batch_idx][idx]
if bbox.sum() != 0:
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_batch_list.append(structure_list)
bbox_batch_list.append(bbox_list)
result = {
"bbox_batch_list": bbox_batch_list,
"structure_batch_list": structure_batch_list,
}
return result
def _bbox_decode(self, bbox, shape):
h, w = shape[:2]
bbox[0::2] *= w
bbox[1::2] *= h
return bbox
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
return np.array(self.dict[self.beg_str])
if beg_or_end == "end":
return np.array(self.dict[self.end_str])
raise TypeError(f"unsupport type {beg_or_end} in get_beg_end_flag_idx")
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = [self.beg_str] + dict_character + [self.end_str]
return dict_character
class TablePreprocess:
def __init__(self):
self.table_max_len = 488
self.build_pre_process_list()
self.ops = self.create_operators()
def __call__(self, data):
"""transform"""
if self.ops is None:
self.ops = []
for op in self.ops:
data = op(data)
if data is None:
return None
return data
def create_operators(
self,
):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(
self.pre_process_list, list
), "operator config should be a list"
ops = []
for operator in self.pre_process_list:
assert (
isinstance(operator, dict) and len(operator) == 1
), "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
op = eval(op_name)(**param)
ops.append(op)
return ops
def build_pre_process_list(self):
resize_op = {
"ResizeTableImage": {
"max_len": self.table_max_len,
}
}
pad_op = {
"PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
}
normalize_op = {
"NormalizeImage": {
"std": [0.229, 0.224, 0.225],
"mean": [0.485, 0.456, 0.406],
"scale": "1./255.",
"order": "hwc",
}
}
to_chw_op = {"ToCHWImage": None}
keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
self.pre_process_list = [
resize_op,
normalize_op,
pad_op,
to_chw_op,
keep_keys_op,
]
class ResizeTableImage:
def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
self.resize_bboxes = resize_bboxes
self.infer_mode = infer_mode
def __call__(self, data):
img = data["image"]
height, width = img.shape[0:2]
ratio = self.max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
resize_img = cv2.resize(img, (resize_w, resize_h))
if self.resize_bboxes and not self.infer_mode:
data["bboxes"] = data["bboxes"] * ratio
data["image"] = resize_img
data["src_img"] = img
data["shape"] = np.array([height, width, ratio, ratio])
data["max_len"] = self.max_len
return data
class PaddingTableImage:
def __init__(self, size, **kwargs):
super(PaddingTableImage, self).__init__()
self.size = size
def __call__(self, data):
img = data["image"]
pad_h, pad_w = self.size
padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data["image"] = padding_img
shape = data["shape"].tolist()
shape.extend([pad_h, pad_w])
data["shape"] = np.array(shape)
return data
class NormalizeImage:
"""normalize image such as substract mean, divide std"""
def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype("float32")
self.std = np.array(std).reshape(shape).astype("float32")
def __call__(self, data):
img = np.array(data["image"])
assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
return data
class ToCHWImage:
"""convert hwc image to chw image"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = np.array(data["image"])
data["image"] = img.transpose((2, 0, 1))
return data
class KeepKeys:
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
def trans_char_ocr_res(ocr_res):
word_result = []
for res in ocr_res:
score = res[2]
for word_box, word in zip(res[3], res[4]):
word_res = []
word_res.append(word_box)
word_res.append(word)
word_res.append(score)
word_result.append(word_res)
return word_result
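# --- Usage sketch (illustrative only, not part of the original pipeline) ---
# How the preprocessing chain above is typically driven; "table_crop.jpg" is a
# hypothetical path, and the calls are left commented out so importing this
# module stays side-effect free.
#
#   preprocessor = TablePreprocess()
#   img = cv2.imread("table_crop.jpg")            # BGR table crop
#   image, shape = preprocessor({"image": img})   # KeepKeys -> [image, shape]
#   # image: float32 CHW array padded to 488x488
#   # shape: [orig_h, orig_w, ratio, ratio, 488, 488], later used by
#   # TableLabelDecode._bbox_decode to map predicted boxes back to pixels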

@ -0,0 +1,8 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from .download_model import DownloadModel
from .load_image import LoadImage
from .logger import Logger
from .utils import is_url
from .vis import VisTable

@ -0,0 +1,67 @@
import io
from pathlib import Path
from typing import Optional, Union
import requests
from tqdm import tqdm
from .logger import Logger
PROJECT_DIR = Path(__file__).resolve().parent.parent
DEFAULT_MODEL_DIR = PROJECT_DIR / "models"
class DownloadModel:
logger = Logger(logger_name=__name__).get_log()
@classmethod
def download(
cls,
model_full_url: Union[str, Path],
save_dir: Union[str, Path, None] = None,
save_model_name: Optional[str] = None,
) -> str:
        if save_dir is None:
            save_dir = DEFAULT_MODEL_DIR
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
if save_model_name is None:
save_model_name = Path(model_full_url).name
save_file_path = save_dir / save_model_name
if save_file_path.exists():
cls.logger.info("%s already exists", save_file_path)
return str(save_file_path)
try:
cls.logger.info("Download %s to %s", model_full_url, save_dir)
file = cls.download_as_bytes_with_progress(model_full_url, save_model_name)
cls.save_file(save_file_path, file)
except Exception as exc:
raise DownloadModelError from exc
return str(save_file_path)
@staticmethod
def download_as_bytes_with_progress(
url: Union[str, Path], name: Optional[str] = None
) -> bytes:
resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180)
total = int(resp.headers.get("content-length", 0))
bio = io.BytesIO()
with tqdm(
desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024
) as pbar:
for chunk in resp.iter_content(chunk_size=65536):
pbar.update(len(chunk))
bio.write(chunk)
return bio.getvalue()
@staticmethod
def save_file(save_path: Union[str, Path], file: bytes):
with open(save_path, "wb") as f:
f.write(file)
class DownloadModelError(Exception):
pass
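# --- Usage sketch (illustrative only) ---
# Fetch a model file into the default "models" directory next to the package;
# the URL below is a placeholder, not an endpoint shipped with this project.
if __name__ == "__main__":
    demo_url = "https://example.com/models/slanet-plus.onnx"  # hypothetical URL
    saved_path = DownloadModel.download(demo_url)
    print(saved_path)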

@ -0,0 +1,131 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from io import BytesIO
from pathlib import Path
from typing import Any, Union
import cv2
import numpy as np
import requests
from PIL import Image, UnidentifiedImageError
from .utils import is_url
root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
class LoadImage:
def __init__(self):
pass
def __call__(self, img: InputType) -> np.ndarray:
if not isinstance(img, InputType.__args__):
raise LoadImageError(
f"The img type {type(img)} does not in {InputType.__args__}"
)
origin_img_type = type(img)
img = self.load_img(img)
img = self.convert_img(img, origin_img_type)
return img
def load_img(self, img: InputType) -> np.ndarray:
if isinstance(img, (str, Path)):
if is_url(img):
img = Image.open(requests.get(img, stream=True, timeout=60).raw)
else:
self.verify_exist(img)
img = Image.open(img)
try:
img = self.img_to_ndarray(img)
except UnidentifiedImageError as e:
raise LoadImageError(f"cannot identify image file {img}") from e
return img
if isinstance(img, bytes):
img = self.img_to_ndarray(Image.open(BytesIO(img)))
return img
if isinstance(img, np.ndarray):
return img
if isinstance(img, Image.Image):
return self.img_to_ndarray(img)
raise LoadImageError(f"{type(img)} is not supported!")
def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
if img.mode == "1":
img = img.convert("L")
return np.array(img)
return np.array(img)
def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
if img.ndim == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if img.ndim == 3:
channel = img.shape[2]
if channel == 1:
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if channel == 2:
return self.cvt_two_to_three(img)
if channel == 3:
if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img
if channel == 4:
return self.cvt_four_to_three(img)
raise LoadImageError(
f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
)
raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
@staticmethod
def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
"""gray + alpha → BGR"""
img_gray = img[..., 0]
img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
img_alpha = img[..., 1]
not_a = cv2.bitwise_not(img_alpha)
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
new_img = cv2.add(new_img, not_a)
return new_img
@staticmethod
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
"""RGBA → BGR"""
r, g, b, a = cv2.split(img)
new_img = cv2.merge((b, g, r))
not_a = cv2.bitwise_not(a)
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
new_img = cv2.bitwise_and(new_img, new_img, mask=a)
mean_color = np.mean(new_img)
if mean_color <= 0.0:
new_img = cv2.add(new_img, not_a)
else:
new_img = cv2.bitwise_not(new_img)
return new_img
@staticmethod
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
raise LoadImageError(f"{file_path} does not exist.")
class LoadImageError(Exception):
pass
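# --- Usage sketch (illustrative only) ---
# LoadImage normalizes every supported input type (path, URL, bytes, ndarray,
# PIL image) to a 3-channel BGR ndarray; "page.jpg" is a hypothetical file.
if __name__ == "__main__":
    loader = LoadImage()
    bgr = loader("page.jpg")   # str/Path input -> BGR ndarray
    again = loader(bgr)        # ndarray input goes through the same channel checks
    print(bgr.shape, again.shape)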

@ -0,0 +1,37 @@
# -*- encoding: utf-8 -*-
# @Author: Jocker1212
# @Contact: xinyijianggo@gmail.com
import logging
import colorlog
class Logger:
def __init__(self, log_level=logging.DEBUG, logger_name=None):
self.logger = logging.getLogger(logger_name)
self.logger.setLevel(log_level)
self.logger.propagate = False
formatter = colorlog.ColoredFormatter(
"%(log_color)s[%(levelname)s] %(asctime)s [RapidTable] %(filename)s:%(lineno)d: %(message)s",
log_colors={
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "red,bg_white",
},
)
        # Reset any existing handlers so repeated instantiation does not
        # produce duplicate log lines, then attach one colored console handler.
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        console_handler.setLevel(log_level)
        self.logger.addHandler(console_handler)
def get_log(self):
return self.logger
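# --- Usage sketch (illustrative only) ---
# Each module grabs its own named logger; all messages share the colored
# "[RapidTable]" format configured above.
if __name__ == "__main__":
    log = Logger(logger_name=__name__).get_log()
    log.info("table model loaded")
    log.warning("falling back to CPU execution")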

@ -0,0 +1,12 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from urllib.parse import urlparse
def is_url(url: str) -> bool:
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except Exception:
return False
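# --- Usage sketch (illustrative only) ---
# is_url only accepts strings carrying both a scheme and a network location.
if __name__ == "__main__":
    assert is_url("https://example.com/model.onnx") is True
    assert is_url("models/slanet-plus.onnx") is False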

@ -0,0 +1,145 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os
from pathlib import Path
from typing import Optional, Union
import cv2
import numpy as np
from .load_image import LoadImage
class VisTable:
def __init__(self):
self.load_img = LoadImage()
def __call__(
self,
img_path: Union[str, Path],
table_results,
save_html_path: Optional[str] = None,
save_drawed_path: Optional[str] = None,
save_logic_path: Optional[str] = None,
):
if save_html_path:
html_with_border = self.insert_border_style(table_results.pred_html)
self.save_html(save_html_path, html_with_border)
table_cell_bboxes = table_results.cell_bboxes
if table_cell_bboxes is None:
return None
img = self.load_img(img_path)
dims_bboxes = table_cell_bboxes.shape[1]
if dims_bboxes == 4:
drawed_img = self.draw_rectangle(img, table_cell_bboxes)
elif dims_bboxes == 8:
drawed_img = self.draw_polylines(img, table_cell_bboxes)
else:
raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
if save_drawed_path:
self.save_img(save_drawed_path, drawed_img)
if save_logic_path and table_results.logic_points:
polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
self.plot_rec_box_with_logic_info(
img, save_logic_path, table_results.logic_points, polygons
)
return drawed_img
def insert_border_style(self, table_html_str: str):
style_res = """<meta charset="UTF-8"><style>
table {
border-collapse: collapse;
width: 100%;
}
th, td {
border: 1px solid black;
padding: 8px;
text-align: center;
}
th {
background-color: #f2f2f2;
}
</style>"""
prefix_table, suffix_table = table_html_str.split("<body>")
html_with_border = f"{prefix_table}{style_res}<body>{suffix_table}"
return html_with_border
def plot_rec_box_with_logic_info(
self, img: np.ndarray, output_path, logic_points, sorted_polygons
):
"""
:param img_path
:param output_path
:param logic_points: [row_start,row_end,col_start,col_end]
:param sorted_polygons: [xmin,ymin,xmax,ymax]
:return:
"""
        # pad the image on the right with a white border for the annotations
img = cv2.copyMakeBorder(
img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
)
        # draw a rectangle for each polygon
for idx, polygon in enumerate(sorted_polygons):
x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
x0 = round(x0)
y0 = round(y0)
x1 = round(x1)
y1 = round(y1)
cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
            # enlarge the font size and line width
            font_scale = 0.9  # originally 0.5
            thickness = 1  # originally 1
logic_point = logic_points[idx]
cv2.putText(
img,
f"row: {logic_point[0]}-{logic_point[1]}",
(x0 + 3, y0 + 8),
cv2.FONT_HERSHEY_PLAIN,
font_scale,
(0, 0, 255),
thickness,
)
cv2.putText(
img,
f"col: {logic_point[2]}-{logic_point[3]}",
(x0 + 3, y0 + 18),
cv2.FONT_HERSHEY_PLAIN,
font_scale,
(0, 0, 255),
thickness,
)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # save the annotated image
self.save_img(output_path, img)
@staticmethod
def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
img_copy = img.copy()
for box in boxes.astype(int):
x1, y1, x2, y2 = box
cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_copy
@staticmethod
def draw_polylines(img: np.ndarray, points) -> np.ndarray:
img_copy = img.copy()
for point in points.astype(int):
point = point.reshape(4, 2)
cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
return img_copy
@staticmethod
def save_img(save_path: Union[str, Path], img: np.ndarray):
cv2.imwrite(str(save_path), img)
@staticmethod
def save_html(save_path: Union[str, Path], html: str):
with open(save_path, "w", encoding="utf-8") as f:
f.write(html)

@ -0,0 +1,18 @@
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re
# proc
## Create Dataset Instance
input_file = "/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/207.jpg"
ds = read_local_images(input_file)[0]
x = ds.apply(doc_analyze, ocr=True)
x = x.pipe_ocr_mode(None)
html = x.get_markdown(None)
content = md(html)
content = re.sub(r'\\([#*_`])', r'\1', content)
print(content)

@ -0,0 +1,213 @@
import os
import tempfile
import cv2
import numpy as np
from paddleocr import PaddleOCR
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
from .rapid_table_pipeline.main import table2md_pipeline
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re
def scanning_document_classify(image):
    # Decide whether the image is a scanned document by detecting a red seal
    # Convert the image from the BGR color space to the HSV color space
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    # Define the HSV ranges for red
lower_red1 = np.array([0, 70, 50])
upper_red1 = np.array([10, 255, 255])
lower_red2 = np.array([170, 70, 50])
upper_red2 = np.array([180, 255, 255])
    # Create two masks, one for low-hue red and one for high-hue red
mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
    # Merge the two masks
mask = cv2.bitwise_or(mask1, mask2)
    # Count the non-zero pixels in the red region
non_zero_pixels = cv2.countNonZero(mask)
return 1 < non_zero_pixels < 1000
def remove_watermark(image):
    # Remove the red seal
_, _, r_channel = cv2.split(image)
r_channel[r_channel > 210] = 255
r_channel = cv2.cvtColor(r_channel, cv2.COLOR_GRAY2BGR)
return r_channel
def html2md(html_content):
md_content = md(html_content)
md_content = re.sub(r'\\([#*_`])', r'\1', md_content)
return md_content
def markdown_rec(image):
    # TODO: a directory of images could be passed in instead of a single file
image_path = f'{tempfile.mktemp()}.jpg'
cv2.imwrite(image_path, image)
try:
ds = read_local_images(image_path)[0]
x = ds.apply(doc_analyze, ocr=True)
x = x.pipe_ocr_mode(None)
html = x.get_markdown(None)
finally:
os.remove(image_path)
return html2md(html)
ocr = PaddleOCR(lang='ch') # need to run only once to download and load model into memory
def text_rec(image):
result = ocr.ocr(image, cls=True)
output = []
for idx in range(len(result)):
res = result[idx]
if not res:
continue
for line in res:
if not line:
continue
output.append(line[1][0])
return output
def table_rec(image):
return table2md_pipeline(image)
table_converter = TableConverter(artifact_dict=create_model_dict())
def scanning_document_rec(image):
    # TODO: the internal OCR could be swapped for PaddleOCR to improve text recognition accuracy
image_path = f'{tempfile.mktemp()}.jpg'
cv2.imwrite(image_path, image)
try:
no_watermark_image = remove_watermark(cv2.imread(image_path))
prefix, suffix = image_path.split('.')
new_image_path = f'{prefix}_remove_watermark.{suffix}'
cv2.imwrite(new_image_path, no_watermark_image)
rendered = table_converter(new_image_path)
text, _, _ = text_from_rendered(rendered)
finally:
os.remove(image_path)
return text, no_watermark_image
def compute_box_distance(box1, box2):
x11, y11, x12, y12 = box1
x21, y21, x22, y22 = box2
    # Compute the horizontal and vertical overlaps
x_overlap = max(0, min(x12, x22) - max(x11, x21))
y_overlap = max(0, min(y12, y22) - max(y11, y21))
    # If the boxes overlap in both x and y, return the negative overlap depth (min = smallest penetration)
if x_overlap > 0 and y_overlap > 0:
return -min(x_overlap, y_overlap)
distances = []
    # If the x projections overlap, compute the distances between top and bottom edges
if x12 > x21 and x11 < x22:
        dist_top = y21 - y12  # bottom edge of box1 to top edge of box2
        dist_bottom = y11 - y22  # top edge of box1 to bottom edge of box2
if dist_top > 0:
distances.append(dist_top)
if dist_bottom > 0:
distances.append(dist_bottom)
    # If the y projections overlap, compute the distances between left and right edges
if y12 > y21 and y11 < y22:
        dist_left = x11 - x22  # left edge of box1 to right edge of box2
        dist_right = x21 - x12  # right edge of box1 to left edge of box2
if dist_left > 0:
distances.append(dist_left)
if dist_right > 0:
distances.append(dist_right)
    # Return the smallest valid distance; if there is none, the edges cannot be aligned, so return None
return min(distances) if distances else None
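# Worked example (illustrative only): for box1 = [0, 0, 100, 50] and
# box2 = [0, 60, 100, 110] the x projections overlap, the boxes do not
# intersect, and the vertical gap is 60 - 50 = 10, so
# compute_box_distance(box1, box2) returns 10. Overlapping boxes yield a
# negative penetration depth; boxes with no aligned projection yield None.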
def assign_tables_to_titles(layout_results, max_distance=200):
tables = [_ for _ in layout_results if _.clsid == 4]
titles = [_ for _ in layout_results if _.clsid == 5]
table_to_title = {}
title_to_table = {}
changed = True
while changed:
changed = False
for title in titles:
title_id = id(title)
best_table = None
min_dist = float('inf')
for table in tables:
table_id = id(table)
dist = compute_box_distance(title.box, table.box)
if dist is None or dist > max_distance:
continue
if dist < min_dist:
min_dist = dist
best_table = table
if best_table is None:
continue
table_id = id(best_table)
current_table = title_to_table.get(title_id)
if current_table is best_table:
                continue  # already the best match, no update needed
prev_title = table_to_title.get(table_id)
if prev_title:
prev_title_id = id(prev_title)
prev_dist = compute_box_distance(prev_title.box, best_table.box)
if prev_dist is not None and prev_dist <= min_dist:
                    continue  # the previous title is bound more closely, skip
                # unbind the old title
                title_to_table.pop(prev_title_id, None)
            # record the new binding
title_to_table[title_id] = best_table
table_to_title[table_id] = title
            changed = True  # a binding was updated, iterate again
    # finally, write the binding results back to the table boxes
for table in tables:
table_id = id(table)
title = table_to_title.get(table_id)
if title:
table.table_title = title.content
else:
table.table_title = None
if __name__ == '__main__':
# content = text_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/5.jpg')
# content = markdown_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/3.jpg')
# content = table_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/6.jpg')
content = scanning_document_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/103.jpg')
print(content)

@ -0,0 +1,128 @@
import psycopg2
from psycopg2 import OperationalError, extras
from loguru import logger
import os
import traceback
# Create a connection (create it once and reuse it across the service)
def create_connection():
try:
conn = psycopg2.connect(
dbname=os.environ['POSTGRESQL_DATABASE'],
user=os.environ['POSTGRESQL_USERNAME'],
password=os.environ['POSTGRESQL_PASSWORD'],
host=os.environ['POSTGRESQL_HOST'],
port=os.environ['POSTGRESQL_PORT']
)
conn.autocommit = False
with conn.cursor() as cur:
cur.execute("SET TIME ZONE 'Asia/Shanghai';")
conn.commit()
return conn
except OperationalError as e:
logger.error(f"连接数据库失败: {e}")
return None
# Insert a single row
def insert_data(conn, table, data_dict, pk_name='id'):
"""
向指定表中插入数据
:param conn: psycopg2 connection 对象
:param table: 表名字符串
:param data_dict: 字典格式的数据比如 {"column1": value1, "column2": value2}
"""
if conn is None:
logger.error("数据库连接无效")
return
try:
with conn.cursor() as cur:
columns = data_dict.keys()
values = [data_dict[col] for col in columns]
            # Build the SQL statement
insert_query = f"""
INSERT INTO {table} ({', '.join(columns)})
VALUES ({', '.join(['%s'] * len(values))})
RETURNING {pk_name}
"""
cur.execute(insert_query, values)
inserted_id = cur.fetchone()[0]
return inserted_id
except Exception as e:
logger.error(f"插入数据失败: {e}")
raise e
def insert_multiple_data(conn, table, data_list, batch_size=100):
"""
批量向指定表中插入数据使用 execute_values
:param conn: psycopg2 connection 对象
:param table: 表名字符串
:param data_list: 包含多个字典的列表每个字典代表一行数据比如 [{"column1": value1, "column2": value2}, ...]
"""
if conn is None:
logger.error("数据库连接无效")
return
try:
with conn.cursor() as cur:
columns = data_list[0].keys()
insert_query = f"""
INSERT INTO {table} ({', '.join(columns)})
VALUES %s
"""
for i in range(0, len(data_list), batch_size):
batch = data_list[i:i + batch_size]
values = [tuple(d.values()) for d in batch]
extras.execute_values(cur, insert_query, values)
except Exception as e:
logger.error(f"批量插入数据失败: {e}")
raise e
conn = create_connection()
def insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, rec_results):
data_dict = {
'path': pdf_path,
'filename': pdf_name,
'process_status': process_status,
'analysis_start_time': start_time,
'analysis_end_time': end_time
}
try:
inserted_id = insert_data(conn, 'pdf_info', data_dict)
if process_status == 2:
data_list = []
for i in range(len(rec_results)):
                # one page
page_no = i + 1
for j in range(len(rec_results[i])):
                    # one layout box
box = rec_results[i][j]
content = box.content
clsid = box.clsid
table_title = box.table_title
order = j
data_dict = {
'layout_type': clsid,
'content': content,
'page_no': page_no,
'pdf_id': inserted_id,
'table_title': table_title,
'display_order': order
}
data_list.append(data_dict)
insert_multiple_data(conn, 'pdf_analysis_output', data_list)
conn.commit()
return inserted_id
except Exception as e:
conn.rollback()
        logger.error(f'Database operation failed!\n{traceback.format_exc()}')
raise e
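# --- Usage sketch (illustrative only) ---
# rec_results is a list of pages, each page a list of objects exposing
# .content, .clsid and .table_title (see LayoutRecognitionResult); the literal
# values below are placeholders, and process_status == 2 is assumed to mean
# "analysis finished" because only that value triggers the per-box insert above.
#
#   pdf_id = insert_pdf2md_table(
#       pdf_path='/data/pdfs/report.pdf',
#       pdf_name='report.pdf',
#       process_status=2,
#       start_time='2024-01-01 10:00:00',
#       end_time='2024-01-01 10:02:30',
#       rec_results=rec_results,
#   )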

@ -0,0 +1,47 @@
from typing import List
from pdf2image import convert_from_path
import os
import paddleclas
import cv2
from .page_detection.utils import PageDetectionResult
paddle_clas_model = paddleclas.PaddleClas(model_name="text_image_orientation")
def pdf2image(pdf_path, output_dir):
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
images = convert_from_path(pdf_path)
for i, image in enumerate(images):
image.save(f'{output_dir}/{i + 1}.jpg')
def image_orient_cls(input_data):
return paddle_clas_model.predict(input_data)
def page_detection_visual(page_detection_result: PageDetectionResult):
img = cv2.imread(page_detection_result.image_path)
for box in page_detection_result.boxes:
pos = box.pos
clsid = box.clsid
confidence = box.confidence
if clsid == 0:
color = (0, 0, 0)
text = 'text'
elif clsid == 1:
color = (255, 0, 0)
text = 'title'
elif clsid == 2:
color = (0, 255, 0)
text = 'figure'
elif clsid == 4:
color = (0, 0, 255)
text = 'table'
        elif clsid == 5:
color = (255, 0, 255)
text = 'table caption'
text = f'{text} {confidence}'
img = cv2.rectangle(img, (int(pos[0]), int(pos[1])), (int(pos[2]), int(pos[3])), color, 2)
cv2.putText(img, text, (int(pos[0]), int(pos[1])), cv2.FONT_HERSHEY_TRIPLEX, 1, color, 2)
return img
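# --- Usage sketch (illustrative only) ---
# Convert a PDF into page images, then draw the layout boxes from one
# PageDetectionResult; the paths below are placeholders.
#
#   pdf2image('/data/pdfs/report.pdf', 'visual_images/report')
#   # ... run layout detection to obtain a PageDetectionResult `page_result` ...
#   vis = page_detection_visual(page_result)
#   cv2.imwrite('visual_images/report/1_layout.jpg', vis)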

@ -0,0 +1,262 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
from scipy.special import softmax
from scipy.interpolate import InterpolatedUnivariateSpline
def line_iou(pred, target, img_w, length=15, aligned=True):
'''
Calculate the line iou value between predictions and targets
Args:
pred: lane predictions, shape: (num_pred, 72)
target: ground truth, shape: (num_target, 72)
img_w: image width
length: extended radius
aligned: True for iou loss calculation, False for pair-wise ious in assign
'''
px1 = pred - length
px2 = pred + length
tx1 = target - length
tx2 = target + length
if aligned:
invalid_mask = target
ovr = paddle.minimum(px2, tx2) - paddle.maximum(px1, tx1)
union = paddle.maximum(px2, tx2) - paddle.minimum(px1, tx1)
else:
num_pred = pred.shape[0]
invalid_mask = target.tile([num_pred, 1, 1])
ovr = (paddle.minimum(px2[:, None, :], tx2[None, ...]) - paddle.maximum(
px1[:, None, :], tx1[None, ...]))
union = (paddle.maximum(px2[:, None, :], tx2[None, ...]) -
paddle.minimum(px1[:, None, :], tx1[None, ...]))
invalid_masks = (invalid_mask < 0) | (invalid_mask >= img_w)
ovr[invalid_masks] = 0.
union[invalid_masks] = 0.
iou = ovr.sum(axis=-1) / (union.sum(axis=-1) + 1e-9)
return iou
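# Worked example (illustrative only): with length=15, a predicted lane column
# at x=100 and a target column at x=110 give an overlap of 115 - 95 = 20 and a
# union of 125 - 85 = 40, i.e. a per-point IoU of 0.5 before the invalid-point
# masking and the average over the 72 sampled rows.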
class Lane:
def __init__(self, points=None, invalid_value=-2., metadata=None):
super(Lane, self).__init__()
self.curr_iter = 0
self.points = points
self.invalid_value = invalid_value
self.function = InterpolatedUnivariateSpline(
points[:, 1], points[:, 0], k=min(3, len(points) - 1))
self.min_y = points[:, 1].min() - 0.01
self.max_y = points[:, 1].max() + 0.01
self.metadata = metadata or {}
def __repr__(self):
return '[Lane]\n' + str(self.points) + '\n[/Lane]'
def __call__(self, lane_ys):
lane_xs = self.function(lane_ys)
lane_xs[(lane_ys < self.min_y) | (lane_ys > self.max_y
)] = self.invalid_value
return lane_xs
def to_array(self, sample_y_range, img_w, img_h):
self.sample_y = range(sample_y_range[0], sample_y_range[1],
sample_y_range[2])
sample_y = self.sample_y
img_w, img_h = img_w, img_h
ys = np.array(sample_y) / float(img_h)
xs = self(ys)
valid_mask = (xs >= 0) & (xs < 1)
lane_xs = xs[valid_mask] * img_w
lane_ys = ys[valid_mask] * img_h
lane = np.concatenate(
(lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), axis=1)
return lane
def __iter__(self):
return self
def __next__(self):
if self.curr_iter < len(self.points):
self.curr_iter += 1
return self.points[self.curr_iter - 1]
self.curr_iter = 0
raise StopIteration
class CLRNetPostProcess(object):
"""
Args:
input_shape (int): network input image size
        ori_shape (int): original image shape before padding
scale_factor (float): scale factor of ori image
enable_mkldnn (bool): whether to open MKLDNN
"""
def __init__(self, img_w, ori_img_h, cut_height, conf_threshold, nms_thres,
max_lanes, num_points):
self.img_w = img_w
self.conf_threshold = conf_threshold
self.nms_thres = nms_thres
self.max_lanes = max_lanes
self.num_points = num_points
self.n_strips = num_points - 1
self.n_offsets = num_points
self.ori_img_h = ori_img_h
self.cut_height = cut_height
self.prior_ys = paddle.linspace(
start=1, stop=0, num=self.n_offsets).astype('float64')
def predictions_to_pred(self, predictions):
"""
Convert predictions to internal Lane structure for evaluation.
"""
lanes = []
for lane in predictions:
lane_xs = lane[6:].clone()
start = min(
max(0, int(round(lane[2].item() * self.n_strips))),
self.n_strips)
length = int(round(lane[5].item()))
end = start + length - 1
end = min(end, len(self.prior_ys) - 1)
if start > 0:
mask = ((lane_xs[:start] >= 0.) &
(lane_xs[:start] <= 1.)).cpu().detach().numpy()[::-1]
mask = ~((mask.cumprod()[::-1]).astype(np.bool_))
lane_xs[:start][mask] = -2
if end < len(self.prior_ys) - 1:
lane_xs[end + 1:] = -2
lane_ys = self.prior_ys[lane_xs >= 0].clone()
lane_xs = lane_xs[lane_xs >= 0]
lane_xs = lane_xs.flip(axis=0).astype('float64')
lane_ys = lane_ys.flip(axis=0)
lane_ys = (lane_ys *
(self.ori_img_h - self.cut_height) + self.cut_height
) / self.ori_img_h
if len(lane_xs) <= 1:
continue
points = paddle.stack(
x=(lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])),
axis=1).squeeze(axis=2)
lane = Lane(
points=points.cpu().numpy(),
metadata={
'start_x': lane[3],
'start_y': lane[2],
'conf': lane[1]
})
lanes.append(lane)
return lanes
def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k):
"""
NMS for lane detection.
        predictions: paddle.Tensor [num_lanes, conf, y, x, length, 72 offsets] [12, 77]
scores: paddle.Tensor [num_lanes]
nms_overlap_thresh: float
top_k: int
"""
# sort by scores to get idx
idx = scores.argsort(descending=True)
keep = []
condidates = predictions.clone()
condidates = condidates.index_select(idx)
while len(condidates) > 0:
keep.append(idx[0])
if len(keep) >= top_k or len(condidates) == 1:
break
ious = []
for i in range(1, len(condidates)):
ious.append(1 - line_iou(
condidates[i].unsqueeze(0),
condidates[0].unsqueeze(0),
img_w=self.img_w,
length=15))
ious = paddle.to_tensor(ious)
mask = ious <= nms_overlap_thresh
id = paddle.where(mask == False)[0]
if id.shape[0] == 0:
break
condidates = condidates[1:].index_select(id)
idx = idx[1:].index_select(id)
keep = paddle.stack(keep)
return keep
def get_lanes(self, output, as_lanes=True):
"""
Convert model output to lanes.
"""
softmax = nn.Softmax(axis=1)
decoded = []
for predictions in output:
if len(predictions) == 0:
decoded.append([])
continue
threshold = self.conf_threshold
scores = softmax(predictions[:, :2])[:, 1]
keep_inds = scores >= threshold
predictions = predictions[keep_inds]
scores = scores[keep_inds]
if predictions.shape[0] == 0:
decoded.append([])
continue
nms_predictions = predictions.detach().clone()
nms_predictions = paddle.concat(
x=[nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1)
nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
nms_predictions[..., 5:] = nms_predictions[..., 5:] * (
self.img_w - 1)
keep = self.lane_nms(
nms_predictions[..., 5:],
scores,
nms_overlap_thresh=self.nms_thres,
top_k=self.max_lanes)
predictions = predictions.index_select(keep)
if predictions.shape[0] == 0:
decoded.append([])
continue
predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips)
if as_lanes:
pred = self.predictions_to_pred(predictions)
else:
pred = predictions
decoded.append(pred)
return decoded
def __call__(self, lanes_list):
lanes = self.get_lanes(lanes_list)
return lanes

@ -0,0 +1,243 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is based on https://github.com/open-mmlab/mmpose/mmpose/core/post_processing/post_transforms.py
"""
import cv2
import numpy as np
class EvalAffine(object):
def __init__(self, size, stride=64):
super(EvalAffine, self).__init__()
self.size = size
self.stride = stride
def __call__(self, image, im_info):
s = self.size
h, w, _ = image.shape
trans, size_resized = get_affine_mat_kernel(h, w, s, inv=False)
image_resized = cv2.warpAffine(image, trans, size_resized)
return image_resized, im_info
def get_affine_mat_kernel(h, w, s, inv=False):
if w < h:
w_ = s
h_ = int(np.ceil((s / w * h) / 64.) * 64)
scale_w = w
scale_h = h_ / w_ * w
else:
h_ = s
w_ = int(np.ceil((s / h * w) / 64.) * 64)
scale_h = h
scale_w = w_ / h_ * h
center = np.array([np.round(w / 2.), np.round(h / 2.)])
size_resized = (w_, h_)
trans = get_affine_transform(
center, np.array([scale_w, scale_h]), 0, size_resized, inv=inv)
return trans, size_resized
def get_affine_transform(center,
input_size,
rot,
output_size,
shift=(0., 0.),
inv=False):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
scale (np.ndarray[2, ]): Scale of the bounding box
wrt [width, height].
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert len(center) == 2
assert len(output_size) == 2
assert len(shift) == 2
if not isinstance(input_size, (np.ndarray, list)):
input_size = np.array([input_size, input_size], dtype=np.float32)
scale_tmp = input_size
shift = np.array(shift)
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = rotate_point([0., src_w * -0.5], rot_rad)
dst_dir = np.array([0., dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
dst = np.zeros((3, 2), dtype=np.float32)
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def get_warp_matrix(theta, size_input, size_dst, size_target):
"""This code is based on
https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
Calculate the transformation matrix under the constraint of unbiased.
Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
Data Processing for Human Pose Estimation (CVPR 2020).
Args:
theta (float): Rotation angle in degrees.
size_input (np.ndarray): Size of input image [w, h].
size_dst (np.ndarray): Size of output image [w, h].
size_target (np.ndarray): Size of ROI in input plane [w, h].
Returns:
matrix (np.ndarray): A matrix for transformation.
"""
theta = np.deg2rad(theta)
matrix = np.zeros((2, 3), dtype=np.float32)
scale_x = size_dst[0] / size_target[0]
scale_y = size_dst[1] / size_target[1]
matrix[0, 0] = np.cos(theta) * scale_x
matrix[0, 1] = -np.sin(theta) * scale_x
matrix[0, 2] = scale_x * (
-0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
np.sin(theta) + 0.5 * size_target[0])
matrix[1, 0] = np.sin(theta) * scale_y
matrix[1, 1] = np.cos(theta) * scale_y
matrix[1, 2] = scale_y * (
-0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
np.cos(theta) + 0.5 * size_target[1])
return matrix
def rotate_point(pt, angle_rad):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle by radian
Returns:
list[float]: Rotated point.
"""
assert len(pt) == 2
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
new_x = pt[0] * cs - pt[1] * sn
new_y = pt[0] * sn + pt[1] * cs
rotated_pt = [new_x, new_y]
return rotated_pt
def _get_3rd_point(a, b):
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): point(x,y)
b (np.ndarray): point(x,y)
Returns:
np.ndarray: The 3rd point.
"""
assert len(a) == 2
assert len(b) == 2
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
class TopDownEvalAffine(object):
"""apply affine transform to image and coords
Args:
trainsize (list): [w, h], the standard size used to train
use_udp (bool): whether to use Unbiased Data Processing.
records(dict): the dict contained the image and coords
Returns:
        records (dict): contains the image and coords after the transform
"""
def __init__(self, trainsize, use_udp=False):
self.trainsize = trainsize
self.use_udp = use_udp
def __call__(self, image, im_info):
rot = 0
imshape = im_info['im_shape'][::-1]
center = im_info['center'] if 'center' in im_info else imshape / 2.
scale = im_info['scale'] if 'scale' in im_info else imshape
if self.use_udp:
trans = get_warp_matrix(
rot, center * 2.0,
[self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
image = cv2.warpAffine(
image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR)
else:
trans = get_affine_transform(center, scale, rot, self.trainsize)
image = cv2.warpAffine(
image,
trans, (int(self.trainsize[0]), int(self.trainsize[1])),
flags=cv2.INTER_LINEAR)
return image, im_info
def expand_crop(images, rect, expand_ratio=0.3):
imgh, imgw, c = images.shape
label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()]
if label != 0:
return None, None, None
org_rect = [xmin, ymin, xmax, ymax]
h_half = (ymax - ymin) * (1 + expand_ratio) / 2.
w_half = (xmax - xmin) * (1 + expand_ratio) / 2.
if h_half > w_half * 4 / 3:
w_half = h_half * 0.75
center = [(ymin + ymax) / 2., (xmin + xmax) / 2.]
ymin = max(0, int(center[0] - h_half))
ymax = min(imgh - 1, int(center[0] + h_half))
xmin = max(0, int(center[1] - w_half))
xmax = min(imgw - 1, int(center[1] + w_half))
return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax], org_rect

@ -0,0 +1,69 @@
from typing import List
from .pdf_detection import Pipeline
from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, PageDetectionResult
from tqdm import tqdm
"""
0 - Text
1 - Title
2 - Figure
3 - Figure caption
4 - Table
5 - Table caption
6 - Header
7 - Footer
8 - Reference
9 - Equation
"""
pipeline = Pipeline('./models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer')
effective_labels = [0, 1, 2, 4, 5]
# NMS priority: a lower index means lower priority; when boxes overlap, tables are kept first
label_scores = [1, 5, 0, 2, 4]
expand_pixel = 10
def layout_analysis(image_paths) -> List[PageDetectionResult]:
layout_analysis_results = []
    for image_path in tqdm(image_paths, 'layout analysis'):
page_detecion_outputs = pipeline(image_path)
layout_boxes = []
for i in range(len(page_detecion_outputs)):
clsid, box, confidence = page_detecion_outputs[i]
if clsid in effective_labels:
layout_boxes.append(LayoutBox(clsid, box, confidence))
page_detecion_outputs = PageDetectionResult(layout_boxes, image_path)
scores = []
poses = []
for box in page_detecion_outputs.boxes:
            # when boxes with the same label overlap, keep the one with the larger area
area = (box.pos[3] - box.pos[1]) * (box.pos[2] - box.pos[0])
area_score = area / 5000000
scores.append(label_scores.index(box.clsid) + area_score)
poses.append(box.pos)
indices = non_max_suppression(poses, scores, 0.2)
_boxes = []
for i in indices:
_boxes.append(page_detecion_outputs.boxes[i])
page_detecion_outputs.boxes = _boxes
for i in range(len(page_detecion_outputs.boxes) - 1, -1, -1):
box = page_detecion_outputs.boxes[i]
if box.clsid in (0, 5):
                # remove Text boxes and Table caption boxes that lie inside Table or Figure boxes (some scanned pages are detected as Figure)
for _box in page_detecion_outputs.boxes:
if _box.clsid != 2 and _box.clsid != 4:
continue
if box.pos[0] > _box.pos[0] and box.pos[1] > _box.pos[1] and box.pos[2] < _box.pos[2] and box.pos[3] < _box.pos[3]:
page_detecion_outputs.boxes.remove(box)
        # merge text and title boxes to simplify conversion to markdown
page_detecion_outputs.boxes = merge_text_and_title_boxes(page_detecion_outputs.boxes, (0, 1))
        # sort the boxes top-to-bottom, then left-to-right
page_detecion_outputs.boxes.sort(key=lambda x: (x.pos[1], x.pos[0]))
layout_analysis_results.append(page_detecion_outputs)
return layout_analysis_results
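# --- Usage sketch (illustrative only) ---
# layout_analysis consumes page-image paths (e.g. produced by pdf2image) and
# returns one PageDetectionResult per page; the paths below are placeholders.
#
#   image_paths = ['visual_images/report/1.jpg', 'visual_images/report/2.jpg']
#   for page in layout_analysis(image_paths):
#       for box in page.boxes:
#           print(box.clsid, box.pos, box.confidence)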

@ -0,0 +1,918 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import glob
import json
from pathlib import Path
import cv2
import numpy as np
import math
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
from picodet_postprocess import PicoDetPostProcess
from clrnet_postprocess import CLRNetPostProcess
from visualize import visualize_box_mask, imshow_lanes
from utils import argsparser, Timer, multiclass_nms, coco_clsid2catid
# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet',
'PPLCNet', 'DETR', 'CenterTrack', 'CLRNet'
}
class Detector(object):
"""
Args:
pred_config (object): config of model, defined by `Config(model_dir)`
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
enable_mkldnn_bfloat16 (bool): whether to turn on mkldnn bfloat16
output_dir (str): The path of output
threshold (float): The threshold of score for visualization
delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
Used by action model.
"""
def __init__(self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='output',
threshold=0.5,
delete_shuffle_pass=False,
use_fd_format=False):
self.pred_config = self.set_config(
model_dir, use_fd_format=use_fd_format)
self.predictor, self.config = load_predictor(
model_dir,
self.pred_config.arch,
run_mode=run_mode,
batch_size=batch_size,
min_subgraph_size=self.pred_config.min_subgraph_size,
device=device,
use_dynamic_shape=self.pred_config.use_dynamic_shape,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
delete_shuffle_pass=delete_shuffle_pass)
self.det_times = Timer()
self.batch_size = batch_size
self.output_dir = output_dir
self.threshold = threshold
self.device = device
def set_config(self, model_dir, use_fd_format):
return PredictConfig(model_dir, use_fd_format=use_fd_format)
def preprocess(self, image_list):
preprocess_ops = []
for op_info in self.pred_config.preprocess_infos:
new_op_info = op_info.copy()
op_type = new_op_info.pop('type')
preprocess_ops.append(eval(op_type)(**new_op_info))
input_im_lst = []
input_im_info_lst = []
for im_path in image_list:
im, im_info = preprocess(im_path, preprocess_ops)
input_im_lst.append(im)
input_im_info_lst.append(im_info)
inputs = create_inputs(input_im_lst, input_im_info_lst)
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
if input_names[i] == 'x':
input_tensor.copy_from_cpu(inputs['image'])
else:
input_tensor.copy_from_cpu(inputs[input_names[i]])
return inputs
def postprocess(self, inputs, result):
# postprocess output of predictor
np_boxes_num = result['boxes_num']
assert isinstance(np_boxes_num, np.ndarray), \
'`np_boxes_num` should be a `numpy.ndarray`'
result = {k: v for k, v in result.items() if v is not None}
return result
def filter_box(self, result, threshold):
np_boxes_num = result['boxes_num']
boxes = result['boxes']
start_idx = 0
filter_boxes = []
filter_num = []
for i in range(len(np_boxes_num)):
boxes_num = np_boxes_num[i]
boxes_i = boxes[start_idx:start_idx + boxes_num, :]
idx = boxes_i[:, 1] > threshold
filter_boxes_i = boxes_i[idx, :]
filter_boxes.append(filter_boxes_i)
filter_num.append(filter_boxes_i.shape[0])
start_idx += boxes_num
boxes = np.concatenate(filter_boxes)
filter_num = np.array(filter_num)
filter_res = {'boxes': boxes, 'boxes_num': filter_num}
return filter_res
def predict(self, repeats=1, run_benchmark=False):
'''
Args:
repeats (int): repeats number for prediction
Returns:
result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                matrix element: [class, score, x_min, y_min, x_max, y_max]
MaskRCNN's result include 'masks': np.ndarray:
shape: [N, im_h, im_w]
'''
# model prediction
np_boxes_num, np_boxes, np_masks = np.array([0]), None, None
if run_benchmark:
for i in range(repeats):
self.predictor.run()
if self.device == 'GPU':
paddle.device.cuda.synchronize()
else:
paddle.device.synchronize(device=self.device.lower())
result = dict(
boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
return result
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
boxes_tensor = self.predictor.get_output_handle(output_names[0])
np_boxes = boxes_tensor.copy_to_cpu()
if len(output_names) == 1:
# some exported model can not get tensor 'bbox_num'
np_boxes_num = np.array([len(np_boxes)])
else:
boxes_num = self.predictor.get_output_handle(output_names[1])
np_boxes_num = boxes_num.copy_to_cpu()
if self.pred_config.mask:
masks_tensor = self.predictor.get_output_handle(output_names[2])
np_masks = masks_tensor.copy_to_cpu()
result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
return result
def merge_batch_result(self, batch_result):
if len(batch_result) == 1:
return batch_result[0]
res_key = batch_result[0].keys()
results = {k: [] for k in res_key}
for res in batch_result:
for k, v in res.items():
results[k].append(v)
for k, v in results.items():
if k not in ['masks', 'segm']:
results[k] = np.concatenate(v)
return results
def get_timer(self):
return self.det_times
def predict_image(self,
image_list,
threshold=0.5,
visual=True):
batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
results = []
for i in range(batch_loop_cnt):
start_index = i * self.batch_size
end_index = min((i + 1) * self.batch_size, len(image_list))
batch_image_list = image_list[start_index:end_index]
# preprocess
self.det_times.preprocess_time_s.start()
inputs = self.preprocess(batch_image_list)
self.det_times.preprocess_time_s.end()
# model prediction
self.det_times.inference_time_s.start()
result = self.predict()
self.det_times.inference_time_s.end()
# postprocess
self.det_times.postprocess_time_s.start()
result = self.postprocess(inputs, result)
self.det_times.postprocess_time_s.end()
self.det_times.img_num += len(batch_image_list)
if visual:
visualize(
batch_image_list,
result,
self.pred_config.labels,
output_dir=self.output_dir,
threshold=self.threshold)
            # TODO: handle batching here
results.append(result)
results = self.merge_batch_result(results)
boxes = results['boxes']
expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
boxes = boxes[expect_boxes, :]
output = []
for dt in boxes:
clsid, box, confidence = int(dt[0]), dt[2:].tolist(), dt[1]
output.append((clsid, box, confidence))
return output
class DetectorSOLOv2(Detector):
"""
Args:
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to open MKLDNN
enable_mkldnn_bfloat16 (bool): Whether to turn on mkldnn bfloat16
output_dir (str): The path of output
threshold (float): The threshold of score for visualization
"""
def __init__(self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='./',
threshold=0.5,
use_fd_format=False):
super(DetectorSOLOv2, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
output_dir=output_dir,
threshold=threshold,
use_fd_format=use_fd_format)
def predict(self, repeats=1, run_benchmark=False):
'''
Args:
repeats (int): repeat number for prediction
Returns:
result (dict): 'segm': np.ndarray,shape:[N, im_h, im_w]
'cate_label': label of segm, shape:[N]
'cate_score': confidence score of segm, shape:[N]
'''
np_segms, np_label, np_score, np_boxes_num = None, None, None, np.array(
[0])
if run_benchmark:
for i in range(repeats):
self.predictor.run()
paddle.device.cuda.synchronize()
result = dict(
segm=np_segms,
label=np_label,
score=np_score,
boxes_num=np_boxes_num)
return result
for i in range(repeats):
self.predictor.run()
output_names = self.predictor.get_output_names()
np_segms = self.predictor.get_output_handle(output_names[
0]).copy_to_cpu()
np_boxes_num = self.predictor.get_output_handle(output_names[
1]).copy_to_cpu()
np_label = self.predictor.get_output_handle(output_names[
2]).copy_to_cpu()
np_score = self.predictor.get_output_handle(output_names[
3]).copy_to_cpu()
result = dict(
segm=np_segms,
label=np_label,
score=np_score,
boxes_num=np_boxes_num)
return result
class DetectorPicoDet(Detector):
"""
Args:
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to turn on MKLDNN
enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16
"""
def __init__(self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='./',
threshold=0.5,
use_fd_format=False):
super(DetectorPicoDet, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
output_dir=output_dir,
threshold=threshold,
use_fd_format=use_fd_format)
def postprocess(self, inputs, result):
# postprocess output of predictor
np_score_list = result['boxes']
np_boxes_list = result['boxes_num']
postprocessor = PicoDetPostProcess(
inputs['image'].shape[2:],
inputs['im_shape'],
inputs['scale_factor'],
strides=self.pred_config.fpn_stride,
nms_threshold=self.pred_config.nms['nms_threshold'])
np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
return result
def predict(self, repeats=1, run_benchmark=False):
'''
Args:
repeats (int): repeat number for prediction
Returns:
result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
                matrix element: [class, score, x_min, y_min, x_max, y_max]
'''
np_score_list, np_boxes_list = [], []
if run_benchmark:
for i in range(repeats):
self.predictor.run()
paddle.device.cuda.synchronize()
result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
return result
for i in range(repeats):
self.predictor.run()
np_score_list.clear()
np_boxes_list.clear()
output_names = self.predictor.get_output_names()
num_outs = int(len(output_names) / 2)
for out_idx in range(num_outs):
np_score_list.append(
self.predictor.get_output_handle(output_names[out_idx])
.copy_to_cpu())
np_boxes_list.append(
self.predictor.get_output_handle(output_names[
out_idx + num_outs]).copy_to_cpu())
result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
return result
class DetectorCLRNet(Detector):
"""
Args:
model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
batch_size (int): size of pre batch in inference
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
cpu_threads (int): cpu threads
enable_mkldnn (bool): whether to turn on MKLDNN
enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16
"""
def __init__(self,
model_dir,
device='CPU',
run_mode='paddle',
batch_size=1,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
output_dir='./',
threshold=0.5,
use_fd_format=False):
super(DetectorCLRNet, self).__init__(
model_dir=model_dir,
device=device,
run_mode=run_mode,
batch_size=batch_size,
trt_min_shape=trt_min_shape,
trt_max_shape=trt_max_shape,
trt_opt_shape=trt_opt_shape,
trt_calib_mode=trt_calib_mode,
cpu_threads=cpu_threads,
enable_mkldnn=enable_mkldnn,
enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
output_dir=output_dir,
threshold=threshold,
use_fd_format=use_fd_format)
deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
self.img_w = yml_conf['img_w']
self.ori_img_h = yml_conf['ori_img_h']
self.cut_height = yml_conf['cut_height']
self.max_lanes = yml_conf['max_lanes']
self.nms_thres = yml_conf['nms_thres']
self.num_points = yml_conf['num_points']
self.conf_threshold = yml_conf['conf_threshold']
def postprocess(self, inputs, result):
# postprocess output of predictor
lanes_list = result['lanes']
postprocessor = CLRNetPostProcess(
img_w=self.img_w,
ori_img_h=self.ori_img_h,
cut_height=self.cut_height,
conf_threshold=self.conf_threshold,
nms_thres=self.nms_thres,
max_lanes=self.max_lanes,
num_points=self.num_points)
lanes = postprocessor(lanes_list)
result = dict(lanes=lanes)
return result
def predict(self, repeats=1, run_benchmark=False):
'''
Args:
repeats (int): repeat number for prediction
Returns:
result (dict): include 'lanes': list of detected lanes for each image
'''
lanes_list = []
if run_benchmark:
for i in range(repeats):
self.predictor.run()
paddle.device.cuda.synchronize()
result = dict(lanes=lanes_list)
return result
for i in range(repeats):
# TODO: check the output of predictor
self.predictor.run()
lanes_list.clear()
output_names = self.predictor.get_output_names()
num_outs = int(len(output_names) / 2)
if num_outs == 0:
lanes_list.append([])
for out_idx in range(num_outs):
lanes_list.append(
self.predictor.get_output_handle(output_names[out_idx])
.copy_to_cpu())
result = dict(lanes=lanes_list)
return result
def create_inputs(imgs, im_info):
"""generate input for different model type
Args:
imgs (list(numpy)): list of images (np.ndarray)
im_info (list(dict)): list of image info
Returns:
inputs (dict): input of model
"""
inputs = {}
im_shape = []
scale_factor = []
if len(imgs) == 1:
inputs['image'] = np.array((imgs[0], )).astype('float32')
inputs['im_shape'] = np.array(
(im_info[0]['im_shape'], )).astype('float32')
inputs['scale_factor'] = np.array(
(im_info[0]['scale_factor'], )).astype('float32')
return inputs
for e in im_info:
im_shape.append(np.array((e['im_shape'], )).astype('float32'))
scale_factor.append(np.array((e['scale_factor'], )).astype('float32'))
inputs['im_shape'] = np.concatenate(im_shape, axis=0)
inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
max_shape_h = max([e[0] for e in imgs_shape])
max_shape_w = max([e[1] for e in imgs_shape])
padding_imgs = []
for img in imgs:
im_c, im_h, im_w = img.shape[:]
padding_im = np.zeros(
(im_c, max_shape_h, max_shape_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = img
padding_imgs.append(padding_im)
inputs['image'] = np.stack(padding_imgs, axis=0)
return inputs
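# Illustrative sketch (not part of the original file): create_inputs pads a batch of
# CHW float32 images to the largest height/width in the batch before stacking.
# The toy shapes below are assumptions for demonstration only.
#
#   img_a = np.zeros((3, 320, 320), dtype=np.float32)
#   img_b = np.zeros((3, 352, 288), dtype=np.float32)
#   infos = [{'im_shape': np.array([320., 320.], dtype=np.float32),
#             'scale_factor': np.array([1., 1.], dtype=np.float32)},
#            {'im_shape': np.array([352., 288.], dtype=np.float32),
#             'scale_factor': np.array([1., 1.], dtype=np.float32)}]
#   batch = create_inputs([img_a, img_b], infos)
#   # batch['image'].shape == (2, 3, 352, 320): both images zero-padded to the max H and W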
class PredictConfig():
"""set config of preprocess, postprocess and visualize
Args:
model_dir (str): root path of model.yml
"""
def __init__(self, model_dir, use_fd_format=False):
# parsing Yaml config for Preprocess
fd_deploy_file = os.path.join(model_dir, 'inference.yml')
ppdet_deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
if use_fd_format:
if not os.path.exists(fd_deploy_file) and os.path.exists(
ppdet_deploy_file):
raise RuntimeError(
"Non-FD format model detected. Please set `use_fd_format` to False."
)
deploy_file = fd_deploy_file
else:
if not os.path.exists(ppdet_deploy_file) and os.path.exists(
fd_deploy_file):
raise RuntimeError(
"FD format model detected. Please set `use_fd_format` to False."
)
deploy_file = ppdet_deploy_file
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
self.check_model(yml_conf)
self.arch = yml_conf['arch']
self.preprocess_infos = yml_conf['Preprocess']
self.min_subgraph_size = yml_conf['min_subgraph_size']
self.labels = yml_conf['label_list']
self.mask = False
self.use_dynamic_shape = yml_conf['use_dynamic_shape']
if 'mask' in yml_conf:
self.mask = yml_conf['mask']
self.tracker = None
if 'tracker' in yml_conf:
self.tracker = yml_conf['tracker']
if 'NMS' in yml_conf:
self.nms = yml_conf['NMS']
if 'fpn_stride' in yml_conf:
self.fpn_stride = yml_conf['fpn_stride']
if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
print(
'The RCNN export model is used for ONNX and it only supports batch_size = 1'
)
self.print_config()
def check_model(self, yml_conf):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for support_model in SUPPORT_MODELS:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], SUPPORT_MODELS))
def print_config(self):
print('----------- Model Configuration -----------')
print('%s: %s' % ('Model Arch', self.arch))
print('%s: ' % ('Transform Order'))
for op_info in self.preprocess_infos:
print('--%s: %s' % ('transform op', op_info['type']))
print('--------------------------------------------')
def load_predictor(model_dir,
arch,
run_mode='paddle',
batch_size=1,
device='CPU',
min_subgraph_size=3,
use_dynamic_shape=False,
trt_min_shape=1,
trt_max_shape=1280,
trt_opt_shape=640,
trt_calib_mode=False,
cpu_threads=1,
enable_mkldnn=False,
enable_mkldnn_bfloat16=False,
delete_shuffle_pass=False):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode needs to be set to True
delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
Used by action model.
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
ValueError: predict by TensorRT need device == 'GPU'.
"""
if device != 'GPU' and run_mode != 'paddle':
raise ValueError(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
.format(run_mode, device))
if paddle.__version__ >= '3.0.0' or paddle.__version__ == '0.0.0':
model_path = model_dir
model_prefix = 'model'
infer_param = os.path.join(model_dir, 'model.pdiparams')
if not os.path.exists(infer_param):
if paddle.framework.use_pir_api():
infer_model = os.path.join(model_dir, 'inference.json')
else:
infer_model = os.path.join(model_dir, 'inference.pdmodel')
if not os.path.exists(infer_model):
raise ValueError(
"Cannot find any inference model in dir: {}.".format(model_dir))
config = Config(model_path, model_prefix)
else:
infer_model = os.path.join(model_dir, 'model.pdmodel')
infer_params = os.path.join(model_dir, 'model.pdiparams')
if not os.path.exists(infer_model):
infer_model = os.path.join(model_dir, 'inference.pdmodel')
infer_params = os.path.join(model_dir, 'inference.pdiparams')
if not os.path.exists(infer_model):
raise ValueError(
"Cannot find any inference model in dir: {},".format(model_dir))
config = Config(infer_model, infer_params)
if device == 'GPU':
# initial GPU memory(M), device ID
config.enable_use_gpu(200, 0)
# optimize graph and fuse op
config.switch_ir_optim(True)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads)
if enable_mkldnn:
try:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
if enable_mkldnn_bfloat16:
config.enable_mkldnn_bfloat16()
except Exception as e:
print(
"The current environment does not support `mkldnn`, so disable mkldnn."
)
pass
precision_map = {
'trt_int8': Config.Precision.Int8,
'trt_fp32': Config.Precision.Float32,
'trt_fp16': Config.Precision.Half
}
if run_mode in precision_map.keys():
config.enable_tensorrt_engine(
workspace_size=(1 << 25) * batch_size,
max_batch_size=batch_size,
min_subgraph_size=min_subgraph_size,
precision_mode=precision_map[run_mode],
use_static=False,
use_calib_mode=trt_calib_mode)
if FLAGS.collect_trt_shape_info:
config.collect_shape_range_info(FLAGS.tuned_trt_shape_file)
elif os.path.exists(FLAGS.tuned_trt_shape_file):
print(f'Use dynamic shape file: '
f'{FLAGS.tuned_trt_shape_file} for TRT...')
config.enable_tuned_tensorrt_dynamic_shape(
FLAGS.tuned_trt_shape_file, True)
if use_dynamic_shape:
min_input_shape = {
'image': [batch_size, 3, trt_min_shape, trt_min_shape],
'scale_factor': [batch_size, 2]
}
max_input_shape = {
'image': [batch_size, 3, trt_max_shape, trt_max_shape],
'scale_factor': [batch_size, 2]
}
opt_input_shape = {
'image': [batch_size, 3, trt_opt_shape, trt_opt_shape],
'scale_factor': [batch_size, 2]
}
config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
opt_input_shape)
print('trt set dynamic shape done!')
# disable print log when predict
config.disable_glog_info()
# enable shared memory
config.enable_memory_optim()
# disable feed, fetch OP, needed by zero_copy_run
config.switch_use_feed_fetch_ops(False)
if delete_shuffle_pass:
config.delete_pass("shuffle_channel_detect_pass")
predictor = create_predictor(config)
return predictor, config
def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
# visualize the predict result
if 'lanes' in result:
for idx, image_file in enumerate(image_list):
lanes = result['lanes'][idx]
img = cv2.imread(image_file)
out_file = os.path.join(output_dir, os.path.basename(image_file))
# hard code
lanes = [lane.to_array([], ) for lane in lanes]
imshow_lanes(img, lanes, out_file=out_file)
return
start_idx = 0
for idx, image_file in enumerate(image_list):
im_bboxes_num = result['boxes_num'][idx]
im_results = {}
if 'boxes' in result:
im_results['boxes'] = result['boxes'][start_idx:start_idx +
im_bboxes_num, :]
if 'masks' in result:
im_results['masks'] = result['masks'][start_idx:start_idx +
im_bboxes_num, :]
if 'segm' in result:
im_results['segm'] = result['segm'][start_idx:start_idx +
im_bboxes_num, :]
if 'label' in result:
im_results['label'] = result['label'][start_idx:start_idx +
im_bboxes_num]
if 'score' in result:
im_results['score'] = result['score'][start_idx:start_idx +
im_bboxes_num]
start_idx += im_bboxes_num
im = visualize_box_mask(
image_file, im_results, labels, threshold=threshold)
img_name = os.path.split(image_file)[-1]
if not os.path.exists(output_dir):
os.makedirs(output_dir)
out_path = os.path.join(output_dir, img_name)
im.save(out_path, quality=95)
print("save result to: " + out_path)
def print_arguments(args):
print('----------- Running Arguments -----------')
for arg, value in sorted(vars(args).items()):
print('%s: %s' % (arg, value))
print('------------------------------------------')
class Pipeline(object):
def __init__(self, model_dir):
if FLAGS.use_fd_format:
deploy_file = os.path.join(model_dir, 'inference.yml')
else:
deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
with open(deploy_file) as f:
yml_conf = yaml.safe_load(f)
arch = yml_conf['arch']
detector_func = 'Detector'
if arch == 'SOLOv2':
detector_func = 'DetectorSOLOv2'
elif arch == 'PicoDet':
detector_func = 'DetectorPicoDet'
elif arch == "CLRNet":
detector_func = 'DetectorCLRNet'
self.detector = eval(detector_func)(
model_dir,
device=FLAGS.device,
run_mode=FLAGS.run_mode,
batch_size=FLAGS.batch_size,
trt_min_shape=FLAGS.trt_min_shape,
trt_max_shape=FLAGS.trt_max_shape,
trt_opt_shape=FLAGS.trt_opt_shape,
trt_calib_mode=FLAGS.trt_calib_mode,
cpu_threads=FLAGS.cpu_threads,
enable_mkldnn=FLAGS.enable_mkldnn,
enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
threshold=FLAGS.threshold,
output_dir=FLAGS.output_dir,
use_fd_format=FLAGS.use_fd_format)
def __call__(self, image_path):
if FLAGS.image_dir is None and image_path is not None:
assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
if isinstance(image_path, str):
image_path = [image_path]
results = self.detector.predict_image(
image_path,
visual=FLAGS.save_images)
return results
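# Usage sketch (illustrative; paths are hypothetical and FLAGS must already be parsed
# as done below before a Pipeline is constructed):
#   pipeline = Pipeline('/path/to/layout_model_dir')
#   results = pipeline('/path/to/page_image.jpg')
#   # `results` is whatever self.detector.predict_image returns for the given image list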
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
print_arguments(FLAGS)
FLAGS.device = 'GPU'
FLAGS.save_images = False
FLAGS.device = FLAGS.device.upper()
assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU', 'MLU', 'GCU'
], "device should be CPU, GPU, XPU, MLU, NPU or GCU"
assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
assert not (
FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True
), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'

@ -0,0 +1,227 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy.special import softmax
def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
iou_threshold: intersection over union threshold.
top_k: keep top_k results. If k <= 0, keep all the results.
candidate_size: only consider the candidates with the highest scores.
Returns:
picked: a list of indexes of the kept boxes
"""
scores = box_scores[:, -1]
boxes = box_scores[:, :-1]
picked = []
indexes = np.argsort(scores)
indexes = indexes[-candidate_size:]
while len(indexes) > 0:
current = indexes[-1]
picked.append(current)
if 0 < top_k == len(picked) or len(indexes) == 1:
break
current_box = boxes[current, :]
indexes = indexes[:-1]
rest_boxes = boxes[indexes, :]
iou = iou_of(
rest_boxes,
np.expand_dims(
current_box, axis=0), )
indexes = indexes[iou <= iou_threshold]
return box_scores[picked, :]
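# Illustrative sketch (not from the original file): hard_nms takes rows of
# [x1, y1, x2, y2, score] and returns the rows that survive suppression.
#
#   candidates = np.array([[0., 0., 10., 10., 0.9],
#                          [1., 1., 11., 11., 0.8],    # overlaps the first box heavily
#                          [20., 20., 30., 30., 0.7]], dtype=np.float32)
#   kept = hard_nms(candidates, iou_threshold=0.5)
#   # kept holds the 0.9 and 0.7 rows; the 0.8 row is suppressed (IoU ≈ 0.68 > 0.5)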
def iou_of(boxes0, boxes1, eps=1e-5):
"""Return intersection-over-union (Jaccard index) of boxes.
Args:
boxes0 (N, 4): ground truth boxes.
boxes1 (N or 1, 4): predicted boxes.
eps: a small number to avoid 0 as denominator.
Returns:
iou (N): IoU values.
"""
overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
overlap_area = area_of(overlap_left_top, overlap_right_bottom)
area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
return overlap_area / (area0 + area1 - overlap_area + eps)
def area_of(left_top, right_bottom):
"""Compute the areas of rectangles given two corners.
Args:
left_top (N, 2): left top corner.
right_bottom (N, 2): right bottom corner.
Returns:
area (N): return the area.
"""
hw = np.clip(right_bottom - left_top, 0.0, None)
return hw[..., 0] * hw[..., 1]
class PicoDetPostProcess(object):
"""
Args:
input_shape (list): network input image size, [h, w]
ori_shape (np.ndarray): original image shapes before padding
scale_factor (np.ndarray): scale factors of the original images
strides (list): FPN strides used to generate anchor centers
score_threshold (float): score threshold for filtering boxes
nms_threshold (float): IoU threshold used in NMS
"""
def __init__(self,
input_shape,
ori_shape,
scale_factor,
strides=[8, 16, 32, 64],
score_threshold=0.4,
nms_threshold=0.5,
nms_top_k=1000,
keep_top_k=100):
self.ori_shape = ori_shape
self.input_shape = input_shape
self.scale_factor = scale_factor
self.strides = strides
self.score_threshold = score_threshold
self.nms_threshold = nms_threshold
self.nms_top_k = nms_top_k
self.keep_top_k = keep_top_k
def warp_boxes(self, boxes, ori_shape):
"""Apply transform to boxes
"""
width, height = ori_shape[1], ori_shape[0]
n = len(boxes)
if n:
# warp points
xy = np.ones((n * 4, 3))
xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
n * 4, 2) # x1y1, x2y2, x1y2, x2y1
# xy = xy @ M.T # transform
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
xy = np.concatenate(
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip boxes
xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
return xy.astype(np.float32)
else:
return boxes
def __call__(self, scores, raw_boxes):
batch_size = raw_boxes[0].shape[0]
reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
out_boxes_num = []
out_boxes_list = []
for batch_id in range(batch_size):
# generate centers
decode_boxes = []
select_scores = []
for stride, box_distribute, score in zip(self.strides, raw_boxes,
scores):
box_distribute = box_distribute[batch_id]
score = score[batch_id]
# centers
fm_h = self.input_shape[0] / stride
fm_w = self.input_shape[1] / stride
h_range = np.arange(fm_h)
w_range = np.arange(fm_w)
ww, hh = np.meshgrid(w_range, h_range)
ct_row = (hh.flatten() + 0.5) * stride
ct_col = (ww.flatten() + 0.5) * stride
center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)
# box distribution to distance
reg_range = np.arange(reg_max + 1)
box_distance = box_distribute.reshape((-1, reg_max + 1))
box_distance = softmax(box_distance, axis=1)
box_distance = box_distance * np.expand_dims(reg_range, axis=0)
box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
box_distance = box_distance * stride
# top K candidate
topk_idx = np.argsort(score.max(axis=1))[::-1]
topk_idx = topk_idx[:self.nms_top_k]
center = center[topk_idx]
score = score[topk_idx]
box_distance = box_distance[topk_idx]
# decode box
decode_box = center + [-1, -1, 1, 1] * box_distance
select_scores.append(score)
decode_boxes.append(decode_box)
# nms
bboxes = np.concatenate(decode_boxes, axis=0)
confidences = np.concatenate(select_scores, axis=0)
picked_box_probs = []
picked_labels = []
for class_index in range(0, confidences.shape[1]):
probs = confidences[:, class_index]
mask = probs > self.score_threshold
probs = probs[mask]
if probs.shape[0] == 0:
continue
subset_boxes = bboxes[mask, :]
box_probs = np.concatenate(
[subset_boxes, probs.reshape(-1, 1)], axis=1)
box_probs = hard_nms(
box_probs,
iou_threshold=self.nms_threshold,
top_k=self.keep_top_k, )
picked_box_probs.append(box_probs)
picked_labels.extend([class_index] * box_probs.shape[0])
if len(picked_box_probs) == 0:
out_boxes_list.append(np.empty((0, 4)))
out_boxes_num.append(0)
else:
picked_box_probs = np.concatenate(picked_box_probs)
# resize output boxes
picked_box_probs[:, :4] = self.warp_boxes(
picked_box_probs[:, :4], self.ori_shape[batch_id])
im_scale = np.concatenate([
self.scale_factor[batch_id][::-1],
self.scale_factor[batch_id][::-1]
])
picked_box_probs[:, :4] /= im_scale
# clas score box
out_boxes_list.append(
np.concatenate(
[
np.expand_dims(
np.array(picked_labels),
axis=-1), np.expand_dims(
picked_box_probs[:, 4], axis=-1),
picked_box_probs[:, :4]
],
axis=1))
out_boxes_num.append(len(picked_labels))
out_boxes_list = np.concatenate(out_boxes_list, axis=0)
out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
return out_boxes_list, out_boxes_num

@ -0,0 +1,549 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
import imgaug.augmenters as iaa
from keypoint_preprocess import get_affine_transform
from PIL import Image
def decode_image(im_file, im_info):
"""read rgb image
Args:
im_file (str|np.ndarray): input can be image path or np.ndarray
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
if isinstance(im_file, str):
with open(im_file, 'rb') as f:
im_read = f.read()
data = np.frombuffer(im_read, dtype='uint8')
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
else:
im = im_file
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return im, im_info
class Resize_Mult32(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, limit_side_len, limit_type, interp=cv2.INTER_LINEAR):
self.limit_side_len = limit_side_len
self.limit_type = limit_type
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, img):
"""
Args:
img (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
limit_side_len = self.limit_side_len
h, w, c = img.shape
# limit the max side
if self.limit_type == 'max':
if h > w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
elif self.limit_type == 'min':
if h < w:
ratio = float(limit_side_len) / h
else:
ratio = float(limit_side_len) / w
elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h, w)
else:
raise ValueError('Unsupported limit_type: {}'.format(self.limit_type))
resize_h = int(h * ratio)
resize_w = int(w * ratio)
resize_h = max(int(round(resize_h / 32) * 32), 32)
resize_w = max(int(round(resize_w / 32) * 32), 32)
im_scale_y = resize_h / float(h)
im_scale_x = resize_w / float(w)
return im_scale_y, im_scale_x
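# Worked example (sketch, not in the original file): with limit_type='max' and
# limit_side_len=960, an image of h=1200, w=800 gets ratio 960/1200 = 0.8, i.e. 960x640,
# both already multiples of 32, so generate_scale returns (0.8, 0.8).
# For h=1000, w=700 the ratio is 0.96 -> 960x672 after rounding to multiples of 32,
# giving per-axis scales 960/1000 and 672/700.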
class Resize(object):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
self.keep_ratio = keep_ratio
self.interp = interp
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
im_channel = im.shape[2]
im_scale_y, im_scale_x = self.generate_scale(im)
im = cv2.resize(
im,
None,
None,
fx=im_scale_x,
fy=im_scale_y,
interpolation=self.interp)
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
im_info['scale_factor'] = np.array(
[im_scale_y, im_scale_x]).astype('float32')
return im, im_info
def generate_scale(self, im):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape = im.shape[:2]
im_c = im.shape[2]
if self.keep_ratio:
im_size_min = np.min(origin_shape)
im_size_max = np.max(origin_shape)
target_size_min = np.min(self.target_size)
target_size_max = np.max(self.target_size)
im_scale = float(target_size_min) / float(im_size_min)
if np.round(im_scale * im_size_max) > target_size_max:
im_scale = float(target_size_max) / float(im_size_max)
im_scale_x = im_scale
im_scale_y = im_scale
else:
resize_h, resize_w = self.target_size
im_scale_y = resize_h / float(origin_shape[0])
im_scale_x = resize_w / float(origin_shape[1])
return im_scale_y, im_scale_x
class ShortSizeScale(object):
"""
Scale images by short size.
Args:
short_size(float | int): Short size of an image will be scaled to the short_size.
fixed_ratio(bool): Set whether to zoom according to a fixed ratio. default: True
do_round(bool): Whether to round up when calculating the zoom ratio. default: False
backend(str): Choose pillow or cv2 as the graphics processing backend. default: 'pillow'
"""
def __init__(self,
short_size,
fixed_ratio=True,
keep_ratio=None,
do_round=False,
backend='pillow'):
self.short_size = short_size
assert (fixed_ratio and not keep_ratio) or (
not fixed_ratio
), "fixed_ratio and keep_ratio cannot be true at the same time"
self.fixed_ratio = fixed_ratio
self.keep_ratio = keep_ratio
self.do_round = do_round
assert backend in [
'pillow', 'cv2'
], "Scale's backend must be pillow or cv2, but get {backend}"
self.backend = backend
def __call__(self, img):
"""
Performs resize operations.
Args:
img (PIL.Image): a PIL.Image.
return:
resized_img: a PIL.Image after scaling.
"""
result_img = None
if isinstance(img, np.ndarray):
h, w, _ = img.shape
elif isinstance(img, Image.Image):
w, h = img.size
else:
raise NotImplementedError
if w <= h:
ow = self.short_size
if self.fixed_ratio: # default is True
oh = int(self.short_size * 4.0 / 3.0)
elif not self.keep_ratio: # no
oh = self.short_size
else:
scale_factor = self.short_size / w
oh = int(h * float(scale_factor) +
0.5) if self.do_round else int(h * self.short_size / w)
ow = int(w * float(scale_factor) +
0.5) if self.do_round else int(w * self.short_size / h)
else:
oh = self.short_size
if self.fixed_ratio:
ow = int(self.short_size * 4.0 / 3.0)
elif not self.keep_ratio: # no
ow = self.short_size
else:
scale_factor = self.short_size / h
oh = int(h * float(scale_factor) +
0.5) if self.do_round else int(h * self.short_size / w)
ow = int(w * float(scale_factor) +
0.5) if self.do_round else int(w * self.short_size / h)
if type(img) == np.ndarray:
img = Image.fromarray(img, mode='RGB')
if self.backend == 'pillow':
result_img = img.resize((ow, oh), Image.BILINEAR)
elif self.backend == 'cv2' and (self.keep_ratio is not None):
result_img = cv2.resize(
img, (ow, oh), interpolation=cv2.INTER_LINEAR)
else:
result_img = Image.fromarray(
cv2.resize(
np.asarray(img), (ow, oh), interpolation=cv2.INTER_LINEAR))
return result_img
class NormalizeImage(object):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
norm_type (str): type in ['mean_std', 'none']
"""
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
self.mean = mean
self.std = std
self.is_scale = is_scale
self.norm_type = norm_type
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.astype(np.float32, copy=False)
if self.is_scale:
scale = 1.0 / 255.0
im *= scale
if self.norm_type == 'mean_std':
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :]
im -= mean
im /= std
return im, im_info
class Permute(object):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def __init__(self, ):
super(Permute, self).__init__()
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im = im.transpose((2, 0, 1)).copy()
return im, im_info
class PadStride(object):
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def __init__(self, stride=0):
self.coarsest_stride = stride
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride = self.coarsest_stride
if coarsest_stride <= 0:
return im, im_info
im_c, im_h, im_w = im.shape
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
return padding_im, im_info
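# Worked example (sketch): with stride=32, a CHW image of shape (3, 750, 1333) is
# zero-padded on the bottom/right to (3, 768, 1344), since ceil(750/32)*32 == 768 and
# ceil(1333/32)*32 == 1344; im_info is returned unchanged.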
class LetterBoxResize(object):
def __init__(self, target_size):
"""
Resize image to target size, convert normalized xywh to pixel xyxy
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
Args:
target_size (int|list): image target size.
"""
super(LetterBoxResize, self).__init__()
if isinstance(target_size, int):
target_size = [target_size, target_size]
self.target_size = target_size
def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
# letterbox: resize a rectangular image to a padded rectangular
shape = img.shape[:2] # [height, width]
ratio_h = float(height) / shape[0]
ratio_w = float(width) / shape[1]
ratio = min(ratio_h, ratio_w)
new_shape = (round(shape[1] * ratio),
round(shape[0] * ratio)) # [width, height]
padw = (width - new_shape[0]) / 2
padh = (height - new_shape[1]) / 2
top, bottom = round(padh - 0.1), round(padh + 0.1)
left, right = round(padw - 0.1), round(padw + 0.1)
img = cv2.resize(
img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT,
value=color) # padded rectangular
return img, ratio, padw, padh
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert len(self.target_size) == 2
assert self.target_size[0] > 0 and self.target_size[1] > 0
height, width = self.target_size
h, w = im.shape[:2]
im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
new_shape = [round(h * ratio), round(w * ratio)]
im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
return im, im_info
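# Worked example (sketch): letterboxing an image of h=720, w=1280 into target_size=[608, 608]
# uses ratio = min(608/720, 608/1280) = 0.475, resizes to 608x342 (WxH), then pads
# roughly 133 px at the top and bottom so the output canvas is exactly 608x608.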
class Pad(object):
def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
"""
Pad image to a specified size.
Args:
size (list[int]): image target size
fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
"""
super(Pad, self).__init__()
if isinstance(size, int):
size = [size, size]
self.size = size
self.fill_value = fill_value
def __call__(self, im, im_info):
im_h, im_w = im.shape[:2]
h, w = self.size
if h == im_h and w == im_w:
im = im.astype(np.float32)
return im, im_info
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
im = canvas
return im, im_info
class WarpAffine(object):
"""Warp affine the image
"""
def __init__(self,
keep_res=False,
pad=31,
input_h=512,
input_w=512,
scale=0.4,
shift=0.1,
down_ratio=4):
self.keep_res = keep_res
self.pad = pad
self.input_h = input_h
self.input_w = input_w
self.scale = scale
self.shift = shift
self.down_ratio = down_ratio
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
h, w = img.shape[:2]
if self.keep_res:
# True in detection eval/infer
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
else:
# False in centertrack eval_mot/eval_mot
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
if not self.keep_res:
out_h = input_h // self.down_ratio
out_w = input_w // self.down_ratio
trans_output = get_affine_transform(c, s, 0, [out_w, out_h])
im_info.update({
'center': c,
'scale': s,
'out_height': out_h,
'out_width': out_w,
'inp_height': input_h,
'inp_width': input_w,
'trans_input': trans_input,
'trans_output': trans_output,
})
return inp, im_info
class CULaneResize(object):
def __init__(self, img_h, img_w, cut_height, prob=0.5):
super(CULaneResize, self).__init__()
self.img_h = img_h
self.img_w = img_w
self.cut_height = cut_height
self.prob = prob
def __call__(self, im, im_info):
# cut
im = im[self.cut_height:, :, :]
# resize
transform = iaa.Sometimes(self.prob,
iaa.Resize({
"height": self.img_h,
"width": self.img_w
}))
im = transform(image=im.copy().astype(np.uint8))
im = im.astype(np.float32) / 255.
# TODO: check whether a transpose is needed, i.e. whether decode_image matches CULaneDataSet's cv2.imread
im = im.transpose(2, 0, 1)
return im, im_info
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
'scale_factor': np.array(
[1., 1.], dtype=np.float32),
'im_shape': None,
}
im, im_info = decode_image(im, im_info)
for operator in preprocess_ops:
im, im_info = operator(im, im_info)
return im, im_info
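# Minimal usage sketch (assumption: this mirrors how the 'Preprocess' list from a
# deploy config is instantiated elsewhere; 'page_001.jpg' is a hypothetical path):
#
#   ops = [Resize(target_size=[800, 1333]),
#          NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#          Permute(),
#          PadStride(stride=32)]
#   im, im_info = preprocess('page_001.jpg', ops)
#   # im is a CHW float32 array; im_info carries 'im_shape' and 'scale_factor'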

@ -0,0 +1,6 @@
export FLAGS_enable_pir_api=0
python3 pdf_detection.py \
--model_dir=/mnt/research/PaddleOCR/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer \
--image_file=/mnt/research/PaddleOCR/demo-75-images/12.jpg \
--device=GPU

@ -0,0 +1,629 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import ast
import argparse
from typing import List, Tuple
import numpy as np
class LayoutBox(object):
def __init__(self, clsid: int, pos: List[float], confidence: float):
self.clsid = clsid
self.pos = pos
self.confidence = confidence
class PageDetectionResult(object):
def __init__(self, boxes: List[LayoutBox], image_path: str):
self.boxes = boxes
self.image_path = image_path
def argsparser():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--image_dir",
type=str,
default=None,
help="Dir of image file, `image_file` has a higher priority.")
parser.add_argument(
"--batch_size", type=int, default=1, help="batch_size for inference.")
parser.add_argument(
"--video_file",
type=str,
default=None,
help="Path of video file, `video_file` or `camera_id` has a highest priority."
)
parser.add_argument(
"--camera_id",
type=int,
default=-1,
help="device id of camera to predict.")
parser.add_argument(
"--threshold", type=float, default=0.5, help="Threshold of score.")
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Directory of output visualization files.")
parser.add_argument(
"--run_mode",
type=str,
default='paddle',
help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU."
)
parser.add_argument(
"--use_gpu",
type=ast.literal_eval,
default=False,
help="Deprecated, please use `--device`.")
parser.add_argument(
"--run_benchmark",
type=ast.literal_eval,
default=False,
help="Whether to predict a image_file repeatedly for benchmark")
parser.add_argument(
"--enable_mkldnn",
type=ast.literal_eval,
default=False,
help="Whether use mkldnn with CPU.")
parser.add_argument(
"--enable_mkldnn_bfloat16",
type=ast.literal_eval,
default=False,
help="Whether use mkldnn bfloat16 inference with CPU.")
parser.add_argument(
"--cpu_threads", type=int, default=1, help="Num of threads with CPU.")
parser.add_argument(
"--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.")
parser.add_argument(
"--trt_max_shape",
type=int,
default=1280,
help="max_shape for TensorRT.")
parser.add_argument(
"--trt_opt_shape",
type=int,
default=640,
help="opt_shape for TensorRT.")
parser.add_argument(
"--trt_calib_mode",
type=bool,
default=False,
help="If the model is produced by TRT offline quantitative "
"calibration, trt_calib_mode need to set True.")
parser.add_argument(
'--save_images',
type=ast.literal_eval,
default=True,
help='Save visualization image results.')
parser.add_argument(
'--save_mot_txts',
action='store_true',
help='Save tracking results (txt).')
parser.add_argument(
'--save_mot_txt_per_img',
action='store_true',
help='Save tracking results (txt) for each image.')
parser.add_argument(
'--scaled',
type=bool,
default=False,
help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
"True in general detector.")
parser.add_argument(
"--tracker_config", type=str, default=None, help=("tracker donfig"))
parser.add_argument(
"--reid_model_dir",
type=str,
default=None,
help=("Directory include:'model.pdiparams', 'model.pdmodel', "
"'infer_cfg.yml', created by tools/export_model.py."))
parser.add_argument(
"--reid_batch_size",
type=int,
default=50,
help="max batch_size for reid model inference.")
parser.add_argument(
'--use_dark',
type=ast.literal_eval,
default=True,
help='whether to use DarkPose to get better keypoint position predictions')
parser.add_argument(
"--action_file",
type=str,
default=None,
help="Path of input file for action recognition.")
parser.add_argument(
"--window_size",
type=int,
default=50,
help="Temporal size of skeleton feature for action recognition.")
parser.add_argument(
"--random_pad",
type=ast.literal_eval,
default=False,
help="Whether do random padding for action recognition.")
parser.add_argument(
"--save_results",
action='store_true',
default=False,
help="Whether save detection result to file using coco format")
parser.add_argument(
'--use_coco_category',
action='store_true',
default=False,
help='Whether to use the coco format dictionary `clsid2catid`')
parser.add_argument(
"--slice_infer",
action='store_true',
help="Whether to slice the image and merge the inference results for small object detection."
)
parser.add_argument(
'--slice_size',
nargs='+',
type=int,
default=[640, 640],
help="Height of the sliced image.")
parser.add_argument(
"--overlap_ratio",
nargs='+',
type=float,
default=[0.25, 0.25],
help="Overlap height ratio of the sliced image.")
parser.add_argument(
"--combine_method",
type=str,
default='nms',
help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
)
parser.add_argument(
"--match_threshold",
type=float,
default=0.6,
help="Combine method matching threshold.")
parser.add_argument(
"--match_metric",
type=str,
default='ios',
help="Combine method matching metric, choose in ['iou', 'ios'].")
parser.add_argument(
"--collect_trt_shape_info",
action='store_true',
default=False,
help="Whether to collect dynamic shape before using tensorrt.")
parser.add_argument(
"--tuned_trt_shape_file",
type=str,
default="shape_range_info.pbtxt",
help="Path of a dynamic shape file for tensorrt.")
parser.add_argument("--use_fd_format", action="store_true")
parser.add_argument(
"--task_type",
type=str,
default='Detection',
help="How to save the coco result, it only work with save_results==True. Optional inputs are Rotate or Detection, default is Detection."
)
return parser
class Times(object):
def __init__(self):
self.time = 0.
# start time
self.st = 0.
# end time
self.et = 0.
def start(self):
self.st = time.time()
def end(self, repeats=1, accumulative=True):
self.et = time.time()
if accumulative:
self.time += (self.et - self.st) / repeats
else:
self.time = (self.et - self.st) / repeats
def reset(self):
self.time = 0.
self.st = 0.
self.et = 0.
def value(self):
return round(self.time, 4)
class Timer(Times):
def __init__(self, with_tracker=False):
super(Timer, self).__init__()
self.with_tracker = with_tracker
self.preprocess_time_s = Times()
self.inference_time_s = Times()
self.postprocess_time_s = Times()
self.tracking_time_s = Times()
self.img_num = 0
def info(self, average=False):
pre_time = self.preprocess_time_s.value()
infer_time = self.inference_time_s.value()
post_time = self.postprocess_time_s.value()
track_time = self.tracking_time_s.value()
total_time = pre_time + infer_time + post_time
if self.with_tracker:
total_time = total_time + track_time
total_time = round(total_time, 4)
print("------------------ Inference Time Info ----------------------")
print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
self.img_num))
preprocess_time = round(pre_time / max(1, self.img_num),
4) if average else pre_time
postprocess_time = round(post_time / max(1, self.img_num),
4) if average else post_time
inference_time = round(infer_time / max(1, self.img_num),
4) if average else infer_time
tracking_time = round(track_time / max(1, self.img_num),
4) if average else track_time
average_latency = total_time / max(1, self.img_num)
qps = 0
if total_time > 0:
qps = 1 / average_latency
print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
average_latency * 1000, qps))
if self.with_tracker:
print(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000,
postprocess_time * 1000, tracking_time * 1000))
else:
print(
"preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
format(preprocess_time * 1000, inference_time * 1000,
postprocess_time * 1000))
def report(self, average=False):
dic = {}
pre_time = self.preprocess_time_s.value()
infer_time = self.inference_time_s.value()
post_time = self.postprocess_time_s.value()
track_time = self.tracking_time_s.value()
dic['preprocess_time_s'] = round(pre_time / max(1, self.img_num),
4) if average else pre_time
dic['inference_time_s'] = round(infer_time / max(1, self.img_num),
4) if average else infer_time
dic['postprocess_time_s'] = round(post_time / max(1, self.img_num),
4) if average else post_time
dic['img_num'] = self.img_num
total_time = pre_time + infer_time + post_time
if self.with_tracker:
dic['tracking_time_s'] = round(track_time / max(1, self.img_num),
4) if average else track_time
total_time = total_time + track_time
dic['total_time_s'] = round(total_time, 4)
return dic
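# Usage sketch (not from the original file): each stage timer wraps its step and
# img_num is incremented per processed image/batch.
#
#   timer = Timer()
#   timer.preprocess_time_s.start()
#   # ... run preprocessing ...
#   timer.preprocess_time_s.end()
#   timer.img_num += 1
#   timer.info(average=True)   # prints per-image latency and QPS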
def multiclass_nms(bboxs, num_classes, match_threshold=0.6, match_metric='iou'):
final_boxes = []
for c in range(num_classes):
idxs = bboxs[:, 0] == c
if np.count_nonzero(idxs) == 0: continue
r = nms(bboxs[idxs, 1:], match_threshold, match_metric)
final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1))
return final_boxes
def nms(dets, match_threshold=0.6, match_metric='iou'):
""" Apply NMS to avoid detecting too many overlapping bounding boxes.
Args:
dets: shape [N, 5], [score, x1, y1, x2, y2]
match_metric: 'iou' or 'ios'
match_threshold: overlap thresh for match metric.
"""
if dets.shape[0] == 0:
return dets[[], :]
scores = dets[:, 0]
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=np.int32)
for _i in range(ndets):
i = order[_i]
if suppressed[i] == 1:
continue
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0.0, xx2 - xx1 + 1)
h = max(0.0, yy2 - yy1 + 1)
inter = w * h
if match_metric == 'iou':
union = iarea + areas[j] - inter
match_value = inter / union
elif match_metric == 'ios':
smaller = min(iarea, areas[j])
match_value = inter / smaller
else:
raise ValueError("match_metric must be 'iou' or 'ios'")
if match_value >= match_threshold:
suppressed[j] = 1
keep = np.where(suppressed == 0)[0]
dets = dets[keep, :]
return dets
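# Illustrative sketch: unlike hard_nms above, nms expects rows of
# [score, x1, y1, x2, y2] and returns the surviving rows.
#
#   dets = np.array([[0.9, 0., 0., 10., 10.],
#                    [0.8, 1., 1., 11., 11.],     # suppressed (IoU ≈ 0.70 with the first)
#                    [0.7, 20., 20., 30., 30.]], dtype=np.float32)
#   nms(dets, match_threshold=0.5, match_metric='iou')   # keeps the 0.9 and 0.7 rows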
coco_clsid2catid = {
0: 1,
1: 2,
2: 3,
3: 4,
4: 5,
5: 6,
6: 7,
7: 8,
8: 9,
9: 10,
10: 11,
11: 13,
12: 14,
13: 15,
14: 16,
15: 17,
16: 18,
17: 19,
18: 20,
19: 21,
20: 22,
21: 23,
22: 24,
23: 25,
24: 27,
25: 28,
26: 31,
27: 32,
28: 33,
29: 34,
30: 35,
31: 36,
32: 37,
33: 38,
34: 39,
35: 40,
36: 41,
37: 42,
38: 43,
39: 44,
40: 46,
41: 47,
42: 48,
43: 49,
44: 50,
45: 51,
46: 52,
47: 53,
48: 54,
49: 55,
50: 56,
51: 57,
52: 58,
53: 59,
54: 60,
55: 61,
56: 62,
57: 63,
58: 64,
59: 65,
60: 67,
61: 70,
62: 72,
63: 73,
64: 74,
65: 75,
66: 76,
67: 77,
68: 78,
69: 79,
70: 80,
71: 81,
72: 82,
73: 84,
74: 85,
75: 86,
76: 87,
77: 88,
78: 89,
79: 90
}
def gaussian_radius(bbox_size, min_overlap):
height, width = bbox_size
a1 = 1
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
radius1 = (b1 + sq1) / (2 * a1)
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
radius2 = (b2 + sq2) / 2
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
radius3 = (b3 + sq3) / 2
return min(radius1, radius2, radius3)
def gaussian2D(shape, sigma_x=1, sigma_y=1):
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y *
sigma_y)))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
def draw_umich_gaussian(heatmap, center, radius, k=1):
"""
draw_umich_gaussian, refer to https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py#L126
"""
diameter = 2 * radius + 1
gaussian = gaussian2D(
(diameter, diameter), sigma_x=diameter / 6, sigma_y=diameter / 6)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
return heatmap
def iou(box1, box2):
"""计算两个框的 IoU交并比"""
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
inter_area = max(0, x2 - x1) * max(0, y2 - y1)
box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
union_area = box1_area + box2_area - inter_area
return inter_area / union_area if union_area != 0 else 0
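# Worked example (sketch): two 10x10 boxes offset by one pixel overlap on a 9x9 region,
# so iou([0, 0, 10, 10], [1, 1, 11, 11]) == 81 / (100 + 100 - 81) ≈ 0.68.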
def non_max_suppression(boxes, scores, iou_threshold):
"""非极大值抑制"""
if not boxes:
return []
indices = np.argsort(scores)[::-1]
selected_boxes = []
while len(indices) > 0:
current = indices[0]
selected_boxes.append(current)
remaining = indices[1:]
filtered_indices = []
for i in remaining:
if iou(boxes[current], boxes[i]) <= iou_threshold:
filtered_indices.append(i)
indices = np.array(filtered_indices)
return selected_boxes
def is_box_inside(inner, outer):
"""判断inner box是否完全在outer box内"""
return (outer[0] <= inner[0] and outer[1] <= inner[1] and
outer[2] >= inner[2] and outer[3] >= inner[3])
def boxes_overlap(box1: List[float], box2: List[float]) -> bool:
"""判断两个框是否有交集"""
x1_max = max(box1[0], box2[0])
y1_max = max(box1[1], box2[1])
x2_min = min(box1[2], box2[2])
y2_min = min(box1[3], box2[3])
return x1_max < x2_min and y1_max < y2_min
def merge_boxes(boxes: List[List[float]]) -> List[float]:
x1 = min(box[0] for box in boxes)
y1 = min(box[1] for box in boxes)
x2 = max(box[2] for box in boxes)
y2 = max(box[3] for box in boxes)
return [x1, y1, x2, y2]
def merge_text_and_title_boxes(data: List[LayoutBox], merged_labels: Tuple[int]) -> List[LayoutBox]:
text_title_boxes = [(i, box) for i, box in enumerate(data) if box.clsid in merged_labels]
other_boxes = [box.pos for box in data if box.clsid in (2, 4, 5)]
text_title_boxes.sort(key=lambda x: x[1].pos[1]) # sort by y1
merged = []
skip_indices = set()
i = 0
while i < len(text_title_boxes):
if i in skip_indices:
i += 1
continue
current_group = [text_title_boxes[i][1].pos]
group_confidences = [text_title_boxes[i][1].confidence]
j = i + 1
while j < len(text_title_boxes):
candidate_box = text_title_boxes[j][1].pos
tentative_merge = merge_boxes(current_group + [candidate_box])
has_intruder = any(boxes_overlap(other, tentative_merge) for other in other_boxes)
if has_intruder:
break
else:
current_group.append(candidate_box)
group_confidences.append(text_title_boxes[j][1].confidence)
skip_indices.add(j)
j += 1
if len(current_group) > 1:
merged_box = LayoutBox(0, merge_boxes(current_group), max(group_confidences))
merged.append(merged_box)
else:
idx = text_title_boxes[i][0]
merged.append(data[idx])
i += 1
remaining = [data[i] for i in range(len(data)) if i not in skip_indices and data[i].clsid not in merged_labels]
return merged + remaining
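# Illustrative sketch (assumption about class ids: 0/1 stand for title/text here, and
# clsid 4 is one of the "other" labels in (2, 4, 5) that blocks merging):
#
#   boxes = [LayoutBox(0, [10, 10, 100, 40], 0.95),    # title
#            LayoutBox(1, [10, 50, 100, 90], 0.90),    # paragraph right below it
#            LayoutBox(4, [10, 300, 100, 400], 0.88)]  # table far away, not a blocker here
#   merged = merge_text_and_title_boxes(boxes, merged_labels=(0, 1))
#   # the first two boxes collapse into one LayoutBox(0, [10, 10, 100, 90], 0.95);
#   # the clsid-4 box is returned unchanged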

@ -0,0 +1,649 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import os
import cv2
import math
import numpy as np
import PIL
from PIL import Image, ImageDraw, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def imagedraw_textsize_c(draw, text):
if int(PIL.__version__.split('.')[0]) < 10:
tw, th = draw.textsize(text)
else:
left, top, right, bottom = draw.textbbox((0, 0), text)
tw, th = right - left, bottom - top
return tw, th
def visualize_box_mask(im, results, labels, threshold=0.5):
"""
Args:
im (str/np.ndarray): path of image/np.ndarray read by cv2
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
matrix element:[class, score, x_min, y_min, x_max, y_max]
MaskRCNN's results include 'masks': np.ndarray:
shape:[N, im_h, im_w]
labels (list): labels:['class1', ..., 'classn']
threshold (float): Threshold of score.
Returns:
im (PIL.Image.Image): visualized image
"""
if isinstance(im, str):
im = Image.open(im).convert('RGB')
elif isinstance(im, np.ndarray):
im = Image.fromarray(im)
if 'masks' in results and 'boxes' in results and len(results['boxes']) > 0:
im = draw_mask(
im, results['boxes'], results['masks'], labels, threshold=threshold)
if 'boxes' in results and len(results['boxes']) > 0:
im = draw_box(im, results['boxes'], labels, threshold=threshold)
if 'segm' in results:
im = draw_segm(
im,
results['segm'],
results['label'],
results['score'],
labels,
threshold=threshold)
return im
def get_color_map_list(num_classes):
"""
Args:
num_classes (int): number of class
Returns:
color_map (list): RGB color list
"""
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map
def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matrix element:[class, score, x_min, y_min, x_max, y_max]
np_masks (np.ndarray): shape:[N, im_h, im_w]
labels (list): labels:['class1', ..., 'classn']
threshold (float): threshold of mask
Returns:
im (PIL.Image.Image): visualized image
"""
color_list = get_color_map_list(len(labels))
w_ratio = 0.4
alpha = 0.7
im = np.array(im).astype('float32')
clsid2color = {}
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
np_masks = np_masks[expect_boxes, :, :]
im_h, im_w = im.shape[:2]
np_masks = np_masks[:, :im_h, :im_w]
for i in range(len(np_masks)):
clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
mask = np_masks[i]
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color_mask = clsid2color[clsid]
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask)
color_mask = np.array(color_mask)
im[idx[0], idx[1], :] *= 1.0 - alpha
im[idx[0], idx[1], :] += alpha * color_mask
return Image.fromarray(im.astype('uint8'))
def draw_box(im, np_boxes, labels, threshold=0.5):
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matrix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
threshold (float): threshold of box
Returns:
im (PIL.Image.Image): visualized image
"""
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
clsid2color = {}
color_list = get_color_map_list(len(labels))
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
for dt in np_boxes:
clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color = tuple(clsid2color[clsid])
if len(bbox) == 4:
xmin, ymin, xmax, ymax = bbox
print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'
'right_bottom:[{:.2f},{:.2f}]'.format(
int(clsid), score, xmin, ymin, xmax, ymax))
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=draw_thickness,
fill=color)
elif len(bbox) == 8:
x1, y1, x2, y2, x3, y3, x4, y4 = bbox
draw.line(
[(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
width=2,
fill=color)
xmin = min(x1, x2, x3, x4)
ymin = min(y1, y2, y3, y4)
# draw label
text = "{} {:.4f}".format(labels[clsid], score)
tw, th = imagedraw_textsize_c(draw, text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
return im
def draw_segm(im,
np_segms,
np_label,
np_score,
labels,
threshold=0.5,
alpha=0.7):
"""
Draw segmentation on image
"""
mask_color_id = 0
w_ratio = .4
color_list = get_color_map_list(len(labels))
im = np.array(im).astype('float32')
clsid2color = {}
np_segms = np_segms.astype(np.uint8)
for i in range(np_segms.shape[0]):
mask, score, clsid = np_segms[i], np_score[i], np_label[i]
if score < threshold:
continue
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color_mask = clsid2color[clsid]
for c in range(3):
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
idx = np.nonzero(mask)
color_mask = np.array(color_mask)
idx0 = np.minimum(idx[0], im.shape[0] - 1)
idx1 = np.minimum(idx[1], im.shape[1] - 1)
im[idx0, idx1, :] *= 1.0 - alpha
im[idx0, idx1, :] += alpha * color_mask
sum_x = np.sum(mask, axis=0)
x = np.where(sum_x > 0.5)[0]
sum_y = np.sum(mask, axis=1)
y = np.where(sum_y > 0.5)[0]
x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
cv2.rectangle(im, (x0, y0), (x1, y1),
tuple(color_mask.astype('int32').tolist()), 1)
bbox_text = '%s %.2f' % (labels[clsid], score)
t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3),
tuple(color_mask.astype('int32').tolist()), -1)
cv2.putText(
im,
bbox_text, (x0, y0 - 2),
cv2.FONT_HERSHEY_SIMPLEX,
0.3, (0, 0, 0),
1,
lineType=cv2.LINE_AA)
return Image.fromarray(im.astype('uint8'))
def get_color(idx):
idx = idx * 3
color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
return color
def visualize_pose(imgfile,
results,
visual_thresh=0.6,
save_name='pose.jpg',
save_dir='output',
returnimg=False,
ids=None):
try:
import matplotlib.pyplot as plt
import matplotlib
plt.switch_backend('agg')
except Exception as e:
print('Matplotlib not found, please install matplotlib, '
'for example: `pip install matplotlib`.')
raise e
skeletons, scores = results['keypoint']
skeletons = np.array(skeletons)
kpt_nums = 17
if len(skeletons) > 0:
kpt_nums = skeletons.shape[1]
if kpt_nums == 17: #plot coco keypoint
EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
(7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14),
(13, 15), (14, 16), (11, 12)]
else: #plot mpii keypoint
EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7), (7, 8),
(8, 9), (10, 11), (11, 12), (13, 14), (14, 15), (8, 12),
(8, 13)]
NUM_EDGES = len(EDGES)
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
cmap = matplotlib.cm.get_cmap('hsv')
plt.figure()
img = cv2.imread(imgfile) if type(imgfile) == str else imgfile
color_set = results['colors'] if 'colors' in results else None
if 'bbox' in results and ids is None:
bboxs = results['bbox']
for j, rect in enumerate(bboxs):
xmin, ymin, xmax, ymax = rect
color = colors[0] if color_set is None else colors[color_set[j] %
len(colors)]
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
canvas = img.copy()
for i in range(kpt_nums):
for j in range(len(skeletons)):
if skeletons[j][i, 2] < visual_thresh:
continue
if ids is None:
                color = colors[i] if color_set is None else colors[color_set[j] % len(colors)]
else:
color = get_color(ids[j])
cv2.circle(
canvas,
tuple(skeletons[j][i, 0:2].astype('int32')),
2,
color,
thickness=-1)
to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0)
fig = matplotlib.pyplot.gcf()
stickwidth = 2
for i in range(NUM_EDGES):
for j in range(len(skeletons)):
edge = EDGES[i]
            if (skeletons[j][edge[0], 2] < visual_thresh or
                    skeletons[j][edge[1], 2] < visual_thresh):
continue
cur_canvas = canvas.copy()
X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]]
Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]]
mX = np.mean(X)
mY = np.mean(Y)
length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
polygon = cv2.ellipse2Poly((int(mY), int(mX)),
(int(length / 2), stickwidth),
int(angle), 0, 360, 1)
if ids is None:
                color = colors[i] if color_set is None else colors[color_set[j] % len(colors)]
else:
color = get_color(ids[j])
cv2.fillConvexPoly(cur_canvas, polygon, color)
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
if returnimg:
return canvas
save_name = os.path.join(
save_dir, os.path.splitext(os.path.basename(imgfile))[0] + '_vis.jpg')
plt.imsave(save_name, canvas[:, :, ::-1])
print("keypoint visualize image saved to: " + save_name)
plt.close()
def visualize_attr(im, results, boxes=None, is_mtmct=False):
if isinstance(im, str):
im = Image.open(im)
im = np.ascontiguousarray(np.copy(im))
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
else:
im = np.ascontiguousarray(np.copy(im))
im_h, im_w = im.shape[:2]
text_scale = max(0.5, im.shape[0] / 3000.)
text_thickness = 1
line_inter = im.shape[0] / 40.
for i, res in enumerate(results):
if boxes is None:
text_w = 3
text_h = 1
elif is_mtmct:
box = boxes[i] # multi camera, bbox shape is x,y, w,h
text_w = int(box[0]) + 3
text_h = int(box[1])
else:
box = boxes[i] # single camera, bbox shape is 0, 0, x,y, w,h
text_w = int(box[2]) + 3
text_h = int(box[3])
for text in res:
text_h += int(line_inter)
text_loc = (text_w, text_h)
cv2.putText(
im,
text,
text_loc,
cv2.FONT_ITALIC,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im
def visualize_action(im,
mot_boxes,
action_visual_collector=None,
action_text="",
video_action_score=None,
video_action_text=""):
im = cv2.imread(im) if isinstance(im, str) else im
im_h, im_w = im.shape[:2]
text_scale = max(1, im.shape[1] / 400.)
text_thickness = 2
if action_visual_collector:
id_action_dict = {}
for collector, action_type in zip(action_visual_collector, action_text):
id_detected = collector.get_visualize_ids()
for pid in id_detected:
id_action_dict[pid] = id_action_dict.get(pid, [])
id_action_dict[pid].append(action_type)
for mot_box in mot_boxes:
            # mot_box format: [mot_id, class, score, xmin, ymin, w, h]
if mot_box[0] in id_action_dict:
text_position = (int(mot_box[3] + mot_box[5] * 0.75),
int(mot_box[4] - 10))
display_text = ', '.join(id_action_dict[mot_box[0]])
cv2.putText(im, display_text, text_position,
cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255), 2)
if video_action_score:
cv2.putText(
im,
video_action_text + ': %.2f' % video_action_score,
(int(im_w / 2), int(15 * text_scale) + 5),
cv2.FONT_ITALIC,
text_scale, (0, 0, 255),
thickness=text_thickness)
return im
def visualize_vehicleplate(im, results, boxes=None):
if isinstance(im, str):
im = Image.open(im)
im = np.ascontiguousarray(np.copy(im))
im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
else:
im = np.ascontiguousarray(np.copy(im))
im_h, im_w = im.shape[:2]
text_scale = max(1.0, im.shape[0] / 400.)
text_thickness = 2
line_inter = im.shape[0] / 40.
for i, res in enumerate(results):
if boxes is None:
text_w = 3
text_h = 1
else:
box = boxes[i]
text = res
if text == "":
continue
text_w = int(box[2])
text_h = int(box[5] + box[3])
text_loc = (text_w, text_h)
cv2.putText(
im,
"LP: " + text,
text_loc,
cv2.FONT_ITALIC,
text_scale, (0, 255, 255),
thickness=text_thickness)
return im
def draw_press_box_lanes(im, np_boxes, labels, threshold=0.5):
"""
Args:
im (PIL.Image.Image): PIL image
np_boxes (np.ndarray): shape:[N,6], N: number of box,
matix element:[class, score, x_min, y_min, x_max, y_max]
labels (list): labels:['class1', ..., 'classn']
threshold (float): threshold of box
Returns:
im (PIL.Image.Image): visualized image
"""
if isinstance(im, str):
im = Image.open(im).convert('RGB')
elif isinstance(im, np.ndarray):
im = Image.fromarray(im)
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
clsid2color = {}
color_list = get_color_map_list(len(labels))
if np_boxes.shape[1] == 7:
np_boxes = np_boxes[:, 1:]
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
for dt in np_boxes:
clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
if clsid not in clsid2color:
clsid2color[clsid] = color_list[clsid]
color = tuple(clsid2color[clsid])
if len(bbox) == 4:
xmin, ymin, xmax, ymax = bbox
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=draw_thickness,
fill=(0, 0, 255))
elif len(bbox) == 8:
x1, y1, x2, y2, x3, y3, x4, y4 = bbox
draw.line(
[(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
width=2,
fill=color)
xmin = min(x1, x2, x3, x4)
ymin = min(y1, y2, y3, y4)
# draw label
text = "{}".format(labels[clsid])
tw, th = imagedraw_textsize_c(draw, text)
draw.rectangle(
[(xmin + 1, ymax - th), (xmin + tw + 1, ymax)], fill=color)
draw.text((xmin + 1, ymax - th), text, fill=(0, 0, 255))
return im
def visualize_vehiclepress(im, results, threshold=0.5):
results = np.array(results)
labels = ['violation']
im = draw_press_box_lanes(im, results, labels, threshold=threshold)
return im
def visualize_lane(im, lanes):
if isinstance(im, str):
im = Image.open(im).convert('RGB')
elif isinstance(im, np.ndarray):
im = Image.fromarray(im)
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
if len(lanes) > 0:
for lane in lanes:
draw.line(
[(lane[0], lane[1]), (lane[2], lane[3])],
width=draw_thickness,
fill=(0, 0, 255))
return im
def visualize_vehicle_retrograde(im, mot_res, vehicle_retrograde_res):
if isinstance(im, str):
im = Image.open(im).convert('RGB')
elif isinstance(im, np.ndarray):
im = Image.fromarray(im)
draw_thickness = min(im.size) // 320
draw = ImageDraw.Draw(im)
lane = vehicle_retrograde_res['fence_line']
if lane is not None:
draw.line(
[(lane[0], lane[1]), (lane[2], lane[3])],
width=draw_thickness,
fill=(0, 0, 0))
mot_id = vehicle_retrograde_res['output']
if mot_id is None or len(mot_id) == 0:
return im
if mot_res is None:
return im
np_boxes = mot_res['boxes']
if np_boxes is not None:
for dt in np_boxes:
if dt[0] not in mot_id:
continue
bbox = dt[3:]
if len(bbox) == 4:
xmin, ymin, xmax, ymax = bbox
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=draw_thickness,
fill=(0, 255, 0))
# draw label
text = "retrograde"
tw, th = imagedraw_textsize_c(draw, text)
draw.rectangle(
[(xmax + 1, ymin - th), (xmax + tw + 1, ymin)],
fill=(0, 255, 0))
draw.text((xmax + 1, ymin - th), text, fill=(0, 255, 0))
return im
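# Color palette cycled by imshow_lanes below so each lane gets a distinct color.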
COLORS = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
(128, 255, 0),
(255, 128, 0),
(128, 0, 255),
(255, 0, 128),
(0, 128, 255),
(0, 255, 128),
(128, 255, 255),
(255, 128, 255),
(255, 255, 128),
(60, 180, 0),
(180, 60, 0),
(0, 60, 180),
(0, 180, 60),
(60, 0, 180),
(180, 0, 60),
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255),
(128, 255, 0),
(255, 128, 0),
(128, 0, 255),
]
def imshow_lanes(img, lanes, show=False, out_file=None, width=4):
lanes_xys = []
for _, lane in enumerate(lanes):
xys = []
for x, y in lane:
if x <= 0 or y <= 0:
continue
x, y = int(x), int(y)
xys.append((x, y))
lanes_xys.append(xys)
lanes_xys.sort(key=lambda xys: xys[0][0] if len(xys) > 0 else 0)
for idx, xys in enumerate(lanes_xys):
for i in range(1, len(xys)):
cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width)
if show:
cv2.imshow('view', img)
cv2.waitKey(0)
if out_file:
if not os.path.exists(os.path.dirname(out_file)):
os.makedirs(os.path.dirname(out_file))
cv2.imwrite(out_file, img)
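
An illustrative usage sketch (editorial addition, not part of this commit), assuming it runs at the bottom of this module so the helpers above are in scope. The detection array uses the [class, score, x_min, y_min, x_max, y_max] layout that draw_press_box_lanes expects, and the lane point lists are made-up values.

```python
import numpy as np
from PIL import Image

if __name__ == '__main__':
    blank = Image.new('RGB', (640, 480), (255, 255, 255))
    # one fake detection: [class_id, score, x_min, y_min, x_max, y_max]
    demo_boxes = np.array([[0, 0.9, 50, 60, 300, 200]], dtype='float32')
    vis = draw_press_box_lanes(blank, demo_boxes, labels=['violation'], threshold=0.5)

    # two fake lane polylines, each a list of (x, y) points
    demo_lanes = [[(10, 470), (320, 240)], [(630, 470), (320, 240)]]
    imshow_lanes(np.array(vis), demo_lanes, out_file='output/lanes_demo.jpg')
```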

@ -0,0 +1,62 @@
{
"bucket_info": {
"bucket-name-1": [
"ak",
"sk",
"endpoint"
],
"bucket-name-2": [
"ak",
"sk",
"endpoint"
]
},
"models-dir": "/root/.cache/modelscope/hub/models/opendatalab/PDF-Extract-Kit-1___0/models",
"layoutreader-model-dir": "/root/.cache/modelscope/hub/models/ppaanngggg/layoutreader",
"device-mode": "cuda",
"layout-config": {
"model": "doclayout_yolo"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
"mfr_model": "unimernet_small",
"enable": false
},
"table-config": {
"model": "rapid_table",
"sub_model": "slanet_plus",
"enable": false,
"max_time": 400
},
"latex-delimiter-config": {
"display": {
"left": "$$",
"right": "$$"
},
"inline": {
"left": "$",
"right": "$"
}
},
"llm-aided-config": {
"formula_aided": {
"api_key": "your_api_key",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"model": "qwen2.5-7b-instruct",
"enable": false
},
"text_aided": {
"api_key": "your_api_key",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"model": "qwen2.5-7b-instruct",
"enable": false
},
"title_aided": {
"api_key": "your_api_key",
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"model": "qwen2.5-32b-instruct",
"enable": false
}
},
"config_version": "1.2.1"
}

@ -0,0 +1,48 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: PicoDet
min_subgraph_size: 3
Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 800
- 608
type: Resize
- is_scale: true
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
type: NormalizeImage
- type: Permute
- stride: 32
type: PadStride
label_list:
- Text
- Title
- Figure
- Figure caption
- Table
- Table caption
- Header
- Footer
- Reference
- Equation
NMS:
keep_top_k: 100
name: MultiClassNMS
nms_threshold: 0.5
nms_top_k: 1000
score_threshold: 0.3
fpn_stride:
- 8
- 16
- 32
- 64

@ -0,0 +1,43 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: PicoDet
min_subgraph_size: 3
Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 800
- 608
type: Resize
- is_scale: true
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
type: NormalizeImage
- type: Permute
- stride: 32
type: PadStride
label_list:
- text
- title
- list
- table
- figure
NMS:
keep_top_k: 100
name: MultiClassNMS
nms_threshold: 0.5
nms_top_k: 1000
score_threshold: 0.3
fpn_stride:
- 8
- 16
- 32
- 64

@ -0,0 +1,39 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: PicoDet
min_subgraph_size: 3
Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 800
- 608
type: Resize
- is_scale: true
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
type: NormalizeImage
- type: Permute
- stride: 32
type: PadStride
label_list:
- table
NMS:
keep_top_k: 100
name: MultiClassNMS
nms_threshold: 0.5
nms_top_k: 1000
score_threshold: 0.3
fpn_stride:
- 8
- 16
- 32
- 64

@ -0,0 +1,43 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: PicoDet
min_subgraph_size: 3
Preprocess:
- interp: 2
keep_ratio: false
target_size:
- 800
- 608
type: Resize
- is_scale: true
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
type: NormalizeImage
- type: Permute
- stride: 32
type: PadStride
label_list:
- text
- title
- list
- table
- figure
NMS:
keep_top_k: 100
name: MultiClassNMS
nms_threshold: 0.5
nms_top_k: 1000
score_threshold: 0.3
fpn_stride:
- 8
- 16
- 32
- 64
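
These four PicoDet inference configs are identical apart from their label_list (a ten-class layout set, a reduced text/title/list/table/figure set, and a table-only detector). A minimal sketch of how such a config is typically consumed — the filename is hypothetical, and PyYAML (pinned in requirements.txt) is assumed:

```python
import yaml

# hypothetical path; the real config filenames are not shown in this diff
with open('helper/page_detection/layout_config.yml') as f:
    cfg = yaml.safe_load(f)

print(cfg['arch'], cfg['label_list'])      # e.g. PicoDet ['Text', 'Title', ...]
for op in cfg['Preprocess']:
    op = dict(op)
    op_type = op.pop('type')               # Resize / NormalizeImage / Permute / PadStride
    print(f'build preprocess op {op_type} with args {op}')
```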

@ -0,0 +1,94 @@
from dotenv import load_dotenv
import os
env = os.environ.get('env', 'dev')
load_dotenv(dotenv_path='.env.dev' if env == 'dev' else '.env', override=True)
import time
import traceback
import cv2
from helper.image_helper import pdf2image, image_orient_cls, page_detection_visual
from helper.page_detection.main import layout_analysis
from helper.content_recognition.main import rec
from helper.db_helper import insert_pdf2md_table
import tempfile
from loguru import logger
import datetime
import shutil
def _pdf2markdown_pipeline(pdf_path, tmp_dir):
start_time = time.time()
# 1. pdf -> images
t1 = time.time()
pdf2image(pdf_path, tmp_dir)
t2 = time.time()
    # 2. image orientation classification
t3 = time.time()
orient_cls_results = image_orient_cls(tmp_dir)
t4 = time.time()
for r in orient_cls_results:
clsid = r[0]['class_ids'][0]
filename = r[0]['filename']
if clsid == 1 or clsid == 3:
img = cv2.imread(filename)
img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
cv2.imwrite(filename, img)
filepaths = os.listdir(tmp_dir)
filepaths.sort(key=lambda x: int(x.split('/')[-1].split('.')[0]))
filepaths = [f'{tmp_dir}/{_}' for _ in filepaths]
# filepaths = filepaths[:75]
    # 3. layout analysis
t5 = time.time()
layout_detection_results = layout_analysis(filepaths)
t6 = time.time()
# 3.1 visual
if int(os.environ['VISUAL']):
        visual_dir = './visual_images'
        os.makedirs(visual_dir, exist_ok=True)  # ensure the output dir exists on a fresh checkout
        for f in os.listdir(visual_dir):
            if f.endswith('.jpg'):
                os.remove(f'{visual_dir}/{f}')
for i in range(len(layout_detection_results)):
vis_img = page_detection_visual(layout_detection_results[i])
cv2.imwrite(f'{visual_dir}/{i + 1}.jpg', vis_img)
    # 4. content recognition
t7 = time.time()
layout_recognition_results = rec(layout_detection_results, tmp_dir)
t8 = time.time()
end_time = time.time()
    logger.info(f'{pdf_path} analysis completed in {round(end_time - start_time, 3)} seconds, including {round(t2 - t1, 3)} seconds for pdf to image, {round(t4 - t3, 3)} seconds for image orientation classification, {round(t6 - t5, 3)} seconds for page detection, and {round(t8 - t7, 3)} seconds for layout recognition, page count: {len(filepaths)}')
return layout_recognition_results
def pdf2markdown_pipeline(pdf_path: str):
pdf_name = pdf_path.split('/')[-1]
start_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
process_status = 0
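    # process_status: 0 = initial value, 2 = analysis succeeded, 3 = analysis failed;
    # the final value is written to the pdf2md table in either branch below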
tmp_dir = tempfile.mkdtemp()
try:
results = _pdf2markdown_pipeline(pdf_path, tmp_dir)
except Exception:
        logger.error(f'PDF analysis error!\n{traceback.format_exc()}')
process_status = 3
end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, None)
pdf_id = None
else:
process_status = 2
end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
pdf_id = insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, results)
finally:
shutil.rmtree(tmp_dir)
return process_status, pdf_id
if __name__ == '__main__':
pdf2markdown_pipeline('/mnt/pdf2markdown/demo.pdf')

@ -0,0 +1,189 @@
albucore==0.0.24
albumentations==2.0.6
annotated-types==0.7.0
anthropic==0.46.0
antlr4-python3-runtime==4.9.3
anyio==4.9.0
astor==0.8.1
babel==2.17.0
bce-python-sdk==0.9.29
beautifulsoup4==4.13.4
blinker==1.9.0
boto3==1.38.8
botocore==1.38.8
Brotli==1.1.0
cachetools==5.5.2
certifi==2025.4.26
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
cobble==0.1.4
coloredlogs==15.0.1
colorlog==6.9.0
contourpy==1.3.2
cryptography==44.0.3
cssselect2==0.8.0
cycler==0.12.1
Cython==3.0.12
decorator==5.2.1
dill==0.4.0
distlib==0.3.9
distro==1.9.0
doclayout_yolo==0.0.2b1
easydict==1.13
EbookLib==0.18
et_xmlfile==2.0.0
faiss-cpu==1.8.0.post1
fast-langdetect==0.2.5
fasttext-predict==0.9.2.4
filelock==3.18.0
filetype==1.2.0
fire==0.7.0
Flask==3.1.0
flask-babel==4.0.0
flatbuffers==25.2.10
fonttools==4.57.0
fsspec==2025.3.2
ftfy==6.3.1
future==1.0.0
gast==0.3.3
google-auth==2.39.0
google-genai==1.13.0
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
huggingface-hub==0.30.2
humanfriendly==10.0
identify==2.6.10
idna==3.10
imageio==2.37.0
imgaug==0.4.0
itsdangerous==2.2.0
Jinja2==3.1.6
jiter==0.9.0
jmespath==1.0.1
joblib==1.4.2
kiwisolver==1.4.8
lazy_loader==0.4
lmdb==1.6.2
loguru==0.7.3
lxml==5.4.0
magic-pdf==1.3.10
mammoth==1.9.0
markdown2==2.5.3
markdownify==0.13.1
marker-pdf==1.6.2
MarkupSafe==3.0.2
matplotlib==3.10.1
modelscope==1.25.0
mpmath==1.3.0
networkx==3.4.2
nodeenv==1.9.1
numpy==1.24.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0
onnxruntime==1.21.1
openai==1.77.0
opencv-contrib-python==4.11.0.86
opencv-python==4.6.0.66
opencv-python-headless==4.11.0.86
openpyxl==3.1.5
opt-einsum==3.3.0
packaging==25.0
paddleclas==2.5.2
paddleocr==2.10.0
paddlepaddle-gpu==2.6.2
pandas==2.2.3
pdf2image==1.17.0
pdfminer.six==20250324
pdftext==0.6.2
pillow==10.4.0
platformdirs==4.3.7
pre_commit==4.2.0
prettytable==3.16.0
protobuf==6.30.2
psutil==7.0.0
psycopg2==2.9.10
py-cpuinfo==9.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.2
pyclipper==1.3.0.post6
pycparser==2.22
pycryptodome==3.22.0
pydantic==2.10.6
pydantic-settings==2.9.1
pydantic_core==2.27.2
pydyf==0.11.0
PyMuPDF==1.24.14
pyparsing==3.2.3
pypdfium2==4.30.0
pyphen==0.17.2
python-dateutil==2.9.0.post0
python-docx==1.1.2
python-dotenv==1.1.0
python-pptx==1.0.2
pytz==2025.2
PyYAML==6.0.2
rapid-table==1.0.5
RapidFuzz==3.13.0
rapidocr==2.0.7
rarfile==4.2
regex==2024.11.6
requests==2.32.3
robust-downloader==0.0.2
rsa==4.9.1
s3transfer==0.12.0
safetensors==0.5.3
scikit-image==0.25.2
scikit-learn==1.6.1
scipy==1.15.2
seaborn==0.13.2
shapely==2.1.0
simsimd==6.2.1
six==1.17.0
sniffio==1.3.1
soupsieve==2.7
stringzilla==3.12.5
surya-ocr==0.13.1
sympy==1.13.1
termcolor==3.1.0
thop==0.1.1.post2209072238
threadpoolctl==3.6.0
tifffile==2025.3.30
tinycss2==1.4.0
tinyhtml5==2.0.0
tokenizers==0.21.1
torch==2.6.0
torchvision==0.21.0
tqdm==4.67.1
transformers==4.51.3
triton==3.2.0
typing-inspection==0.4.0
typing_extensions==4.13.2
tzdata==2025.2
ujson==5.10.0
ultralytics==8.3.127
ultralytics-thop==2.0.14
urllib3==2.4.0
virtualenv==20.31.1
visualdl==2.5.3
wcwidth==0.2.13
weasyprint==63.1
webencodings==0.5.1
websockets==15.0.1
Werkzeug==3.1.3
XlsxWriter==3.2.3
zopfli==0.2.3.post1

@ -0,0 +1,16 @@
from flask import Flask, request
import requests
from pipeline import pdf2markdown_pipeline
app = Flask(__name__)
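# Synchronously converts each requested PDF to markdown and reports the result
# (pdfId + processStatus) back to the caller's webhook, one POST per file.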
@app.route('/pdf-qa-server/pdf-to-md', methods=['POST'])  # the handler reads a JSON body, so restrict to POST
def pdf2markdown():
data = request.json
pdf_paths = data['pathList']
callback_url = data['webhookUrl']
for pdf_path in pdf_paths:
process_status, pdf_id = pdf2markdown_pipeline(pdf_path)
requests.post(callback_url, json={'pdfId': pdf_id, 'processStatus': process_status})
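
For reference, a hypothetical client invocation of this endpoint (host, port, and the webhook receiver URL are assumptions, not values from this commit):

```python
import requests

resp = requests.post(
    'http://127.0.0.1:5000/pdf-qa-server/pdf-to-md',
    json={
        'pathList': ['/mnt/pdf2markdown/demo.pdf'],
        'webhookUrl': 'http://127.0.0.1:8080/pdf-qa-server/webhook',  # hypothetical receiver
    })
```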