first commit
parent c7a0a4a452
commit 76e53818ce
@ -0,0 +1,7 @@
POSTGRESQL_HOST=
POSTGRESQL_PORT=
POSTGRESQL_USERNAME=
POSTGRESQL_PASSWORD=
POSTGRESQL_DATABASE=

VISUAL=0
@ -0,0 +1,7 @@
POSTGRESQL_HOST=192.168.10.137
POSTGRESQL_PORT=54321
POSTGRESQL_USERNAME=postgres
POSTGRESQL_PASSWORD=123456
POSTGRESQL_DATABASE=pdf-qa

VISUAL=1
@ -0,0 +1,5 @@
venv
*.pdf
.vscode
visual_images/*.jpg
__pycache__
@ -0,0 +1,67 @@
import json
import shutil
import os

import requests
from modelscope import snapshot_download


def download_json(url):
    # Download the JSON file
    response = requests.get(url)
    response.raise_for_status()  # check that the request succeeded
    return response.json()


def download_and_modify_json(url, local_filename, modifications):
    if os.path.exists(local_filename):
        data = json.load(open(local_filename))
        config_version = data.get('config_version', '0.0.0')
        if config_version < '1.2.0':
            data = download_json(url)
    else:
        data = download_json(url)

    # Apply the modifications
    for key, value in modifications.items():
        data[key] = value

    # Save the modified content
    with open(local_filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


if __name__ == '__main__':
    mineru_patterns = [
        # "models/Layout/LayoutLMv3/*",
        "models/Layout/YOLO/*",
        "models/MFD/YOLO/*",
        "models/MFR/unimernet_hf_small_2503/*",
        "models/OCR/paddleocr_torch/*",
        # "models/TabRec/TableMaster/*",
        # "models/TabRec/StructEqTable/*",
    ]
    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
    layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
    model_dir = model_dir + '/models'
    print(f'model_dir is: {model_dir}')
    print(f'layoutreader_model_dir is: {layoutreader_model_dir}')

    # paddleocr_model_dir = model_dir + '/OCR/paddleocr'
    # user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
    # if os.path.exists(user_paddleocr_dir):
    #     shutil.rmtree(user_paddleocr_dir)
    # shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)

    json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json'
    config_file_name = 'magic-pdf.json'
    home_dir = os.path.expanduser('~')
    config_file = os.path.join(home_dir, config_file_name)

    json_mods = {
        'models-dir': model_dir,
        'layoutreader-model-dir': layoutreader_model_dir,
    }

    download_and_modify_json(json_url, config_file, json_mods)
    print(f'The configuration file has been configured successfully, the path is: {config_file}')
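Note: a quick way to verify the step above once the script has run; the config path and the two keys come from the script itself (nothing else is assumed):

import json
import os

# The script writes ~/magic-pdf.json with these two keys pointing at the downloads.
cfg = json.load(open(os.path.join(os.path.expanduser('~'), 'magic-pdf.json')))
print(cfg['models-dir'])
print(cfg['layoutreader-model-dir'])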
@ -0,0 +1,116 @@
from typing import List

import cv2
from tqdm import tqdm

from .utils import scanning_document_classify, text_rec, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark


class LayoutRecognitionResult(object):

    def __init__(self, clsid, content, box, table_title=None):
        self.clsid = clsid
        self.content = content
        self.box = box
        self.table_title = table_title

    def __repr__(self):
        return f"[{self.clsid}] {self.content}"


expand_pixel = 10


def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
    page_recognition_results = []

    for page_idx in tqdm(range(len(page_detection_results)), 'text recognition'):
        results = page_detection_results[page_idx]
        if not results.boxes:
            page_recognition_results.append([])
            continue

        img = cv2.imread(results.image_path)
        h, w = img.shape[:2]

        for layout in results.boxes:
            # Expand each box slightly to make OCR easier
            layout.pos[0] -= expand_pixel
            layout.pos[1] -= expand_pixel
            layout.pos[2] += expand_pixel
            layout.pos[3] += expand_pixel

            layout.pos[0] = max(0, layout.pos[0])
            layout.pos[1] = max(0, layout.pos[1])
            layout.pos[2] = min(w, layout.pos[2])
            layout.pos[3] = min(h, layout.pos[3])

        outputs = []

        is_scanning_document = False
        for layout in results.boxes:
            x1, y1, x2, y2 = layout.pos
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            layout_img = img[y1: y2, x1: x2]
            content = None
            if layout.clsid == 0:
                # text
                content = markdown_rec(layout_img)
            elif layout.clsid == 2:
                # figure
                if scanning_document_classify(layout_img):
                    # scanned document
                    is_scanning_document = True
                    content, layout_img = scanning_document_rec(layout_img)
                    source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
            elif layout.clsid == 4:
                # table
                if scanning_document_classify(layout_img):
                    is_scanning_document = True
                    content, layout_img = scanning_document_rec(layout_img)
                    source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
                else:
                    content = table_rec(layout_img)
            elif layout.clsid == 5:
                # table caption
                ocr_results = text_rec(layout_img)
                content = ''
                for o in ocr_results:
                    content += f'{o}\n'
                while content.endswith('\n'):
                    content = content[:-1]

            if not content:
                continue

            result = LayoutRecognitionResult(layout.clsid, content, layout.pos)
            outputs.append(result)

        if is_scanning_document and len(outputs) == 1:
            # For scanned documents, additionally extract the title
            h, w = source_page_no_watermark_img.shape[:2]
            if h > w:
                title_img = source_page_no_watermark_img[:360, :w, ...]

                # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
                # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
                # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
            else:
                title_img = source_page_no_watermark_img[:410, :w, ...]

                # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
                # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
                # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
            title = text_rec(title_img)
            outputs[0].table_title = '\n'.join(title)
        else:
            # Automatically assign each table the title closest to it
            assign_tables_to_titles(outputs)

        # Table captions can be dropped now
        outputs = [_ for _ in outputs if _.clsid != 5]
        # Map 2 (figure) and 4 (table) to the database enum value 1 (table)
        for o in outputs:
            if o.clsid == 2 or o.clsid == 4:
                o.clsid = 1
        page_recognition_results.append(outputs)

    return page_recognition_results
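Note: a minimal sketch of driving rec() with stub inputs; the page and layout objects below are hypothetical stand-ins built from the only attributes the function reads (image_path, boxes, and per-box pos/clsid):

from types import SimpleNamespace

# Hypothetical stand-ins; real inputs come from the layout detection stage.
text_box = SimpleNamespace(clsid=0, pos=[50, 40, 400, 120])   # clsid 0 = text
page = SimpleNamespace(image_path='tmp/1.jpg', boxes=[text_box])

pages = rec([page], tmp_dir='tmp')   # -> List[List[LayoutRecognitionResult]]
print(pages[0])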
@ -0,0 +1,5 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from .main import RapidTable, RapidTableInput
from .utils import VisTable
@ -0,0 +1,258 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import argparse
import copy
import importlib
import time
from dataclasses import asdict, dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import cv2
import numpy as np
from markdownify import markdownify as md

from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable

from .table_matcher import TableMatch
from .table_structure import TableStructurer, TableStructureUnitable

logger = Logger(logger_name=__name__).get_log()
root_dir = Path(__file__).resolve().parent


class ModelType(Enum):
    PPSTRUCTURE_EN = "ppstructure_en"
    PPSTRUCTURE_ZH = "ppstructure_zh"
    SLANETPLUS = "slanet_plus"
    UNITABLE = "unitable"


ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
KEY_TO_MODEL_URL = {
    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
    ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
    ModelType.UNITABLE.value: {
        "encoder": f"{ROOT_URL}/unitable/encoder.pth",
        "decoder": f"{ROOT_URL}/unitable/decoder.pth",
        "vocab": f"{ROOT_URL}/unitable/vocab.json",
    },
}


@dataclass
class RapidTableInput:
    model_type: Optional[str] = ModelType.SLANETPLUS.value
    model_path: Union[str, Path, None, Dict[str, str]] = None
    use_cuda: bool = False
    device: str = "cpu"


@dataclass
class RapidTableOutput:
    pred_html: Optional[str] = None
    cell_bboxes: Optional[np.ndarray] = None
    logic_points: Optional[np.ndarray] = None
    elapse: Optional[float] = None


class RapidTable:
    def __init__(self, config: RapidTableInput):
        self.model_type = config.model_type
        if self.model_type not in KEY_TO_MODEL_URL:
            model_list = ",".join(KEY_TO_MODEL_URL)
            raise ValueError(
                f"{self.model_type} is not supported. The currently supported models are {model_list}."
            )

        config.model_path = self.get_model_path(config.model_type, config.model_path)
        if self.model_type == ModelType.UNITABLE.value:
            self.table_structure = TableStructureUnitable(asdict(config))
        else:
            self.table_structure = TableStructurer(asdict(config))

        self.table_matcher = TableMatch()

        try:
            self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
        except ModuleNotFoundError:
            self.ocr_engine = None

        self.load_img = LoadImage()

    def __call__(
        self,
        img_content: Union[str, np.ndarray, bytes, Path],
        ocr_result: Optional[List[Tuple[List[List[float]], str, float]]] = None,
    ) -> RapidTableOutput:
        if self.ocr_engine is None and ocr_result is None:
            raise ValueError(
                "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
            )

        img = self.load_img(img_content)

        s = time.perf_counter()
        h, w = img.shape[:2]

        if ocr_result is None:
            ocr_result = self.ocr_engine(img)
            ocr_result = list(
                zip(
                    ocr_result.boxes,
                    ocr_result.txts,
                    ocr_result.scores,
                )
            )
        dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)

        pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))

        # Rescale boxes back to the original image size for the slanet-plus model
        if self.model_type == ModelType.SLANETPLUS.value:
            cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)

        pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)

        # Filter out placeholder bboxes
        mask = ~np.all(cell_bboxes == 0, axis=1)
        cell_bboxes = cell_bboxes[mask]

        logic_points = self.table_matcher.decode_logic_points(pred_structures)
        elapse = time.perf_counter() - s
        return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)

    def get_boxes_recs(
        self, ocr_result: List[Tuple[List[List[float]], str, float]], h: int, w: int
    ) -> Tuple[np.ndarray, List[Tuple[str, float]]]:
        dt_boxes, rec_res, scores = list(zip(*ocr_result))
        rec_res = list(zip(rec_res, scores))

        r_boxes = []
        for box in dt_boxes:
            box = np.array(box)
            x_min = max(0, box[:, 0].min() - 1)
            x_max = min(w, box[:, 0].max() + 1)
            y_min = max(0, box[:, 1].min() - 1)
            y_max = min(h, box[:, 1].max() + 1)
            box = [x_min, y_min, x_max, y_max]
            r_boxes.append(box)
        dt_boxes = np.array(r_boxes)
        return dt_boxes, rec_res

    def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
        h, w = img.shape[:2]
        resized = 488
        ratio = min(resized / h, resized / w)
        w_ratio = resized / (w * ratio)
        h_ratio = resized / (h * ratio)
        cell_bboxes[:, 0::2] *= w_ratio
        cell_bboxes[:, 1::2] *= h_ratio
        return cell_bboxes

    @staticmethod
    def get_model_path(
        model_type: str, model_path: Union[str, Path, None]
    ) -> Union[str, Dict[str, str]]:
        if model_path is not None:
            return model_path

        model_url = KEY_TO_MODEL_URL.get(model_type, None)
        if isinstance(model_url, str):
            model_path = DownloadModel.download(model_url)
            return model_path

        if isinstance(model_url, dict):
            model_paths = {}
            for k, url in model_url.items():
                model_paths[k] = DownloadModel.download(
                    url, save_model_name=f"{model_type}_{Path(url).name}"
                )
            return model_paths

        raise ValueError(f"Model URL must be str or dict, got {type(model_url)}.")


def parse_args(arg_list: Optional[List[str]] = None):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--vis",
        action="store_true",
        default=False,
        help="Whether to visualize the layout results.",
    )
    parser.add_argument(
        "-img", "--img_path", type=str, required=True, help="Path to image for layout."
    )
    parser.add_argument(
        "-m",
        "--model_type",
        type=str,
        default=ModelType.SLANETPLUS.value,
        choices=list(KEY_TO_MODEL_URL),
    )
    args = parser.parse_args(arg_list)
    return args


try:
    ocr_engine = importlib.import_module("rapidocr").RapidOCR()
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError(
        "Please install rapidocr via `pip install rapidocr`"
    ) from exc

input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value)
table_engine = RapidTable(input_args)


def table2md_pipeline(img):
    rapid_ocr_output = ocr_engine(img)
    ocr_result = list(
        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
    )
    table_results = table_engine(img, ocr_result)
    html_content = table_results.pred_html
    md_content = md(html_content)
    return md_content


# def main(arg_list: Optional[List[str]] = None):
#     args = parse_args(arg_list)
#
#     try:
#         ocr_engine = importlib.import_module("rapidocr").RapidOCR()
#     except ModuleNotFoundError as exc:
#         raise ModuleNotFoundError(
#             "Please install rapidocr via `pip install rapidocr`"
#         ) from exc
#
#     input_args = RapidTableInput(model_type=args.model_type)
#     table_engine = RapidTable(input_args)
#
#     img = cv2.imread(args.img_path)
#
#     rapid_ocr_output = ocr_engine(img)
#     ocr_result = list(
#         zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
#     )
#     table_results = table_engine(img, ocr_result)
#     print(table_results.pred_html)
#
#     viser = VisTable()
#     if args.vis:
#         img_path = Path(args.img_path)
#
#         save_dir = img_path.resolve().parent
#         save_html_path = save_dir / f"{Path(img_path).stem}.html"
#         save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
#         viser(img_path, table_results, save_html_path, save_drawed_path)


if __name__ == "__main__":
    res = table2md_pipeline(cv2.imread('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/11.jpg'))
    print('*' * 50)
    print(res)
Binary file not shown.
Binary file not shown.
@ -0,0 +1,4 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from .matcher import TableMatch
@ -0,0 +1,199 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
import numpy as np

from .utils import compute_iou, distance


class TableMatch:
    def __init__(self, filter_ocr_result=True, use_master=False):
        self.filter_ocr_result = filter_ocr_result
        self.use_master = use_master

    def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
        if self.filter_ocr_result:
            dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
        matched_index = self.match_result(dt_boxes, cell_bboxes)
        pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
        return pred_html

    def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
        matched = {}
        for i, gt_box in enumerate(dt_boxes):
            distances = []
            for j, pred_box in enumerate(cell_bboxes):
                if len(pred_box) == 8:
                    pred_box = [
                        np.min(pred_box[0::2]),
                        np.min(pred_box[1::2]),
                        np.max(pred_box[0::2]),
                        np.max(pred_box[1::2]),
                    ]
                distances.append(
                    (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
                )  # compute iou and l1 distance
            sorted_distances = distances.copy()
            # select det box by iou and l1 distance
            sorted_distances = sorted(
                sorted_distances, key=lambda item: (item[1], item[0])
            )
            # the best match must exceed min_iou
            if sorted_distances[0][1] >= 1 - min_iou:
                continue

            if distances.index(sorted_distances[0]) not in matched:
                matched[distances.index(sorted_distances[0])] = [i]
            else:
                matched[distances.index(sorted_distances[0])].append(i)
        return matched

    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
        end_html = []
        td_index = 0
        for tag in pred_structures:
            if "</td>" not in tag:
                end_html.append(tag)
                continue

            if "<td></td>" == tag:
                end_html.extend("<td>")

            if td_index in matched_index.keys():
                b_with = False
                if (
                    "<b>" in ocr_contents[matched_index[td_index][0]]
                    and len(matched_index[td_index]) > 1
                ):
                    b_with = True
                    end_html.extend("<b>")

                for i, td_index_index in enumerate(matched_index[td_index]):
                    content = ocr_contents[td_index_index][0]
                    if len(matched_index[td_index]) > 1:
                        if len(content) == 0:
                            continue

                        if content[0] == " ":
                            content = content[1:]

                        if "<b>" in content:
                            content = content[3:]

                        if "</b>" in content:
                            content = content[:-4]

                        if len(content) == 0:
                            continue

                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
                            content += " "
                    end_html.extend(content)

                if b_with:
                    end_html.extend("</b>")

            if "<td></td>" == tag:
                end_html.append("</td>")
            else:
                end_html.append(tag)

            td_index += 1

        # Filter <thead></thead><tbody></tbody> elements
        filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
        end_html = [v for v in end_html if v not in filter_elements]
        return "".join(end_html), end_html

    def decode_logic_points(self, pred_structures):
        logic_points = []
        current_row = 0
        current_col = 0
        max_rows = 0
        max_cols = 0
        occupied_cells = {}  # tracks cells already taken by a span

        def is_occupied(row, col):
            return (row, col) in occupied_cells

        def mark_occupied(row, col, rowspan, colspan):
            for r in range(row, row + rowspan):
                for c in range(col, col + colspan):
                    occupied_cells[(r, c)] = True

        i = 0
        while i < len(pred_structures):
            token = pred_structures[i]

            if token == "<tr>":
                current_col = 0  # reset the column counter at each <tr>
            elif token == "</tr>":
                current_row += 1  # the row is finished, advance the row counter
            elif token.startswith("<td"):
                colspan = 1
                rowspan = 1
                j = i
                if token != "<td></td>":
                    j += 1
                    # extract the colspan and rowspan attributes
                    while j < len(pred_structures) and not pred_structures[
                        j
                    ].startswith(">"):
                        if "colspan=" in pred_structures[j]:
                            colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
                        elif "rowspan=" in pred_structures[j]:
                            rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
                        j += 1

                    # skip the attribute tokens that were just consumed
                    i = j

                # find the next unoccupied column
                while is_occupied(current_row, current_col):
                    current_col += 1

                # compute the logical coordinates
                r_start = current_row
                r_end = current_row + rowspan - 1
                col_start = current_col
                col_end = current_col + colspan - 1

                # record the logical coordinates
                logic_points.append([r_start, r_end, col_start, col_end])

                # mark the occupied cells
                mark_occupied(r_start, col_start, rowspan, colspan)

                # advance the current column
                current_col += colspan

                # update the running maxima for rows and columns
                max_rows = max(max_rows, r_end + 1)
                max_cols = max(max_cols, col_end + 1)

            i += 1

        return logic_points

    def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
        y1 = cell_bboxes[:, 1::2].min()
        new_dt_boxes = []
        new_rec_res = []

        for box, rec in zip(dt_boxes, rec_res):
            if np.max(box[1::2]) < y1:
                continue
            new_dt_boxes.append(box)
            new_rec_res.append(rec)
        return new_dt_boxes, new_rec_res
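Note: a small worked example of decode_logic_points; the token stream is hand-written to mimic the structure model's output, with td attributes as separate tokens:

tokens = ["<tr>", "<td", ' rowspan="2"', ">", "</td>", "<td></td>", "</tr>",
          "<tr>", "<td></td>", "</tr>"]
# Each entry is [row_start, row_end, col_start, col_end]: the spanning cell
# occupies rows 0-1 of column 0, pushing the second-row cell into column 1.
print(TableMatch().decode_logic_points(tokens))
# [[0, 1, 0, 0], [0, 0, 1, 1], [1, 1, 1, 1]]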
@ -0,0 +1,249 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import copy
import re


def deal_isolate_span(thead_part):
    """
    Deal with isolated span cases in this function.
    They are caused by wrong predictions from the structure recognition model,
    e.g. predicting <td rowspan="2"></td> as <td></td> rowspan="2"></b></td>.
    :param thead_part:
    :return:
    """
    # 1. find out isolated span tokens.
    isolate_pattern = (
        r'<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
        r'<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
        r'<td></td> rowspan="(\d)+"></b></td>|'
        r'<td></td> colspan="(\d)+"></b></td>'
    )
    isolate_iter = re.finditer(isolate_pattern, thead_part)
    isolate_list = [i.group() for i in isolate_iter]

    # 2. find out the span numbers from the step 1 results.
    span_pattern = (
        r' rowspan="(\d)+" colspan="(\d)+"|'
        r' colspan="(\d)+" rowspan="(\d)+"|'
        r' rowspan="(\d)+"|'
        r' colspan="(\d)+"'
    )
    corrected_list = []
    for isolate_item in isolate_list:
        span_part = re.search(span_pattern, isolate_item)
        spanStr_in_isolateItem = span_part.group()
        # 3. merge the span number into the span token format string.
        if spanStr_in_isolateItem is not None:
            corrected_item = f"<td{spanStr_in_isolateItem}></td>"
            corrected_list.append(corrected_item)
        else:
            corrected_list.append(None)

    # 4. replace the original isolated tokens.
    for corrected_item, isolate_item in zip(corrected_list, isolate_list):
        if corrected_item is not None:
            thead_part = thead_part.replace(isolate_item, corrected_item)
        else:
            pass
    return thead_part


def deal_duplicate_bb(thead_part):
    """
    Deal with duplicate <b> or </b> after replacement.
    Keep one <b></b> in a <td></td> token.
    :param thead_part:
    :return:
    """
    # 1. find out <td></td> in <thead></thead>.
    td_pattern = (
        r'<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
        r'<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
        r'<td rowspan="(\d)+">(.+?)</td>|'
        r'<td colspan="(\d)+">(.+?)</td>|'
        r"<td>(.*?)</td>"
    )
    td_iter = re.finditer(td_pattern, thead_part)
    td_list = [t.group() for t in td_iter]

    # 2. are there multiple <b></b> in a <td></td>?
    new_td_list = []
    for td_item in td_list:
        if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
            # multiple <b></b> in one <td></td> case.
            # 1. remove all <b></b>
            td_item = td_item.replace("<b>", "").replace("</b>", "")
            # 2. replace <td> -> <td><b>, </td> -> </b></td>.
            td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
            new_td_list.append(td_item)
        else:
            new_td_list.append(td_item)

    # 3. replace the original thead part.
    for td_item, new_td_item in zip(td_list, new_td_list):
        thead_part = thead_part.replace(td_item, new_td_item)
    return thead_part


def deal_bb(result_token):
    """
    In our opinion, <b></b> always occurs in the context of <thead></thead> text.
    This function finds all tokens in <thead></thead> and inserts <b></b> manually.
    :param result_token:
    :return:
    """
    # find out the <thead></thead> parts.
    thead_pattern = "<thead>(.*?)</thead>"
    if re.search(thead_pattern, result_token) is None:
        return result_token
    thead_part = re.search(thead_pattern, result_token).group()
    origin_thead_part = copy.deepcopy(thead_part)

    # check whether "rowspan" or "colspan" occurs in the <thead></thead> parts.
    span_pattern = r'<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan="(\d)+">|<td colspan="(\d)+">'
    span_iter = re.finditer(span_pattern, thead_part)
    span_list = [s.group() for s in span_iter]
    has_span_in_head = True if len(span_list) > 0 else False

    if not has_span_in_head:
        # branch 1: <thead></thead> does not include "rowspan" or "colspan".
        # 1. replace <td> with <td><b>, and </td> with </b></td>
        # 2. text-line recognition may already produce <b> or </b>,
        #    so replace <b><b> with <b>, and </b></b> with </b>
        thead_part = (
            thead_part.replace("<td>", "<td><b>")
            .replace("</td>", "</b></td>")
            .replace("<b><b>", "<b>")
            .replace("</b></b>", "</b>")
        )
    else:
        # branch 2: <thead></thead> includes "rowspan" or "colspan".
        # Firstly, deal with the rowspan/colspan cases:
        # 1. replace > with ><b>
        # 2. replace </td> with </b></td>
        # 3. text-line recognition may already produce <b> or </b>,
        #    so replace <b><b> with <b>, and </b><b> with </b>

        # Secondly, deal with the ordinary cases like branch 1

        # replace ">" with "><b>"
        replaced_span_list = []
        for sp in span_list:
            replaced_span_list.append(sp.replace(">", "><b>"))
        for sp, rsp in zip(span_list, replaced_span_list):
            thead_part = thead_part.replace(sp, rsp)

        # replace "</td>" with "</b></td>"
        thead_part = thead_part.replace("</td>", "</b></td>")

        # remove duplicated <b> via re.sub
        mb_pattern = "(<b>)+"
        single_b_string = "<b>"
        thead_part = re.sub(mb_pattern, single_b_string, thead_part)

        mgb_pattern = "(</b>)+"
        single_gb_string = "</b>"
        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)

        # ordinary cases like branch 1
        thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")

    # convert <td><b></b></td> back to <td></td>: an empty cell gets no <b></b>,
    # but a space cell (<td> </td>) is suitable for <td><b> </b></td>
    thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
    # deal with duplicated <b></b>
    thead_part = deal_duplicate_bb(thead_part)
    # deal with isolated span tokens, which are caused by wrong structure predictions,
    # e.g. PMC5994107_011_00.png
    thead_part = deal_isolate_span(thead_part)
    # replace the original result with the new thead part.
    result_token = result_token.replace(origin_thead_part, thead_part)
    return result_token


def deal_eb_token(master_token):
    """
    Post-process <eb></eb>, <eb1></eb1>, ...
    emptyBboxTokenDict = {
        "[]": '<eb></eb>',
        "[' ']": '<eb1></eb1>',
        "['<b>', ' ', '</b>']": '<eb2></eb2>',
        "['\\u2028', '\\u2028']": '<eb3></eb3>',
        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
        "['<b>', '</b>']": '<eb5></eb5>',
        "['<i>', ' ', '</i>']": '<eb6></eb6>',
        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
        "['<i>', '</i>']": '<eb9></eb9>',
        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
    }
    :param master_token:
    :return:
    """
    master_token = master_token.replace("<eb></eb>", "<td></td>")
    master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
    master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
    master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
    master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
    master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
    master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
    master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
    master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
    master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
    master_token = master_token.replace(
        "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
    )
    return master_token


def distance(box_1, box_2):
    x1, y1, x2, y2 = box_1
    x3, y3, x4, y4 = box_2
    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
    dis_2 = abs(x3 - x1) + abs(y3 - y1)
    dis_3 = abs(x4 - x2) + abs(y4 - y2)
    return dis + min(dis_2, dis_3)


def compute_iou(rec1, rec2):
    """
    computing IoU
    :param rec1: (y0, x0, y1, x1), which reflects
            (top, left, bottom, right)
    :param rec2: (y0, x0, y1, x1)
    :return: scalar value of IoU
    """
    # computing the area of each rectangle
    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])

    # computing the sum area
    sum_area = S_rec1 + S_rec2

    # find each edge of the intersecting rectangle
    left_line = max(rec1[1], rec2[1])
    right_line = min(rec1[3], rec2[3])
    top_line = max(rec1[0], rec2[0])
    bottom_line = min(rec1[2], rec2[2])

    # judge whether there is an intersection
    if left_line >= right_line or top_line >= bottom_line:
        return 0.0

    intersect = (right_line - left_line) * (bottom_line - top_line)
    return (intersect / (sum_area - intersect)) * 1.0
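Note: a quick numeric check of the two geometry helpers above (both boxes simply have to share the same coordinate ordering):

box_a = [0, 0, 10, 10]
box_b = [5, 5, 15, 15]
print(compute_iou(box_a, box_b))  # 25 / 175 = 0.142857...
print(distance(box_a, box_b))     # 20 + min(10, 10) = 30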
@ -0,0 +1,15 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .table_structure import TableStructurer
from .table_structure_unitable import TableStructureUnitable
@ -0,0 +1,58 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import Any, Dict

import numpy as np

from .utils import OrtInferSession, TableLabelDecode, TablePreprocess


class TableStructurer:
    def __init__(self, config: Dict[str, Any]):
        self.preprocess_op = TablePreprocess()

        self.session = OrtInferSession(config)

        self.character = self.session.get_metadata()
        self.postprocess_op = TableLabelDecode(self.character)

    def __call__(self, img):
        starttime = time.time()
        data = {"image": img}
        data = self.preprocess_op(data)
        img = data[0]
        if img is None:
            # callers unpack three values, so return a matching triple
            return None, None, 0
        img = np.expand_dims(img, axis=0)
        img = img.copy()

        outputs = self.session([img])

        preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}

        shape_list = np.expand_dims(data[-1], axis=0)
        post_result = self.postprocess_op(preds, [shape_list])

        bbox_list = post_result["bbox_batch_list"][0]

        structure_str_list = post_result["structure_batch_list"][0]
        structure_str_list = structure_str_list[0]
        structure_str_list = (
            ["<html>", "<body>", "<table>"]
            + structure_str_list
            + ["</table>", "</body>", "</html>"]
        )
        elapse = time.time() - starttime
        return structure_str_list, bbox_list, elapse
@ -0,0 +1,911 @@
from dataclasses import dataclass
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.transformer import _get_activation_fn

# Structure-token ids allowed during decoding: id 1 (the eos id used below)
# plus the contiguous range 12-509. Behaviorally identical to the original
# one-id-per-line literal list.
TOKEN_WHITE_LIST = [1, *range(12, 510)]


class ImgLinearBackbone(nn.Module):
    def __init__(
        self,
        d_model: int,
        patch_size: int,
        in_chan: int = 3,
    ) -> None:
        super().__init__()

        self.conv_proj = nn.Conv2d(
            in_chan,
            out_channels=d_model,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.d_model = d_model

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv_proj(x)
        x = x.flatten(start_dim=-2).transpose(1, 2)
        return x


class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.patch_size = 16
        self.d_model = 768
        self.dropout = 0
        self.activation = "gelu"
        self.norm_first = True
        self.ff_ratio = 4
        self.nhead = 12
        self.max_seq_len = 1024
        self.n_encoder_layer = 12
        encoder_layer = nn.TransformerEncoderLayer(
            self.d_model,
            nhead=self.nhead,
            dim_feedforward=self.ff_ratio * self.d_model,
            dropout=self.dropout,
            activation=self.activation,
            batch_first=True,
            norm_first=self.norm_first,
        )
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm = norm_layer(self.d_model)
        self.backbone = ImgLinearBackbone(
            d_model=self.d_model, patch_size=self.patch_size
        )
        self.pos_embed = PositionEmbedding(
            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
        )

    def forward(self, x: Tensor) -> Tensor:
        src_feature = self.backbone(x)
        src_feature = self.pos_embed(src_feature)
        memory = self.encoder(src_feature)
        memory = self.norm(memory)
        return memory


class PositionEmbedding(nn.Module):
    def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
        super().__init__()
        self.embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        # assume x is batch first
        if input_pos is None:
            _pos = torch.arange(x.shape[1], device=x.device)
        else:
            _pos = input_pos
        out = self.embedding(_pos)
        return self.dropout(out + x)


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        padding_idx: int,
    ) -> None:
        super().__init__()
        assert vocab_size > 0
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)

    def forward(self, x: Tensor) -> Tensor:
        return self.embedding(x)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class ModelArgs:
    n_layer: int = 4
    n_head: int = 12
    dim: int = 768
    intermediate_size: int = None
    head_dim: int = 64
    activation: str = "gelu"
    norm_first: bool = True

    def __post_init__(self):
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head


class KVCache(nn.Module):
    def __init__(
        self,
        max_batch_size,
        max_seq_length,
        n_heads,
        head_dim,
        dtype=torch.bfloat16,
        device="cpu",
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer(
            "k_cache",
            torch.zeros(cache_shape, dtype=dtype, device=device),
            persistent=False,
        )
        self.register_buffer(
            "v_cache",
            torch.zeros(cache_shape, dtype=dtype, device=device),
            persistent=False,
        )

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        # assert input_pos.shape[0] == k_val.shape[2]

        bs = k_val.shape[0]
        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:bs, :, input_pos] = k_val
        v_out[:bs, :, input_pos] = v_val

        return k_out[:bs], v_out[:bs]


class GPTFastDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.vocab_size = 960
        self.padding_idx = 2
        self.prefix_token_id = 11
        self.eos_id = 1
        self.max_seq_len = 1024
        self.dropout = 0
        self.d_model = 768
        self.nhead = 12
        self.activation = "gelu"
        self.norm_first = True
        self.n_decoder_layer = 4
        config = ModelArgs(
            n_layer=self.n_decoder_layer,
            n_head=self.nhead,
            dim=self.d_model,
            intermediate_size=self.d_model * 4,
            activation=self.activation,
            norm_first=self.norm_first,
        )
        self.config = config
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.token_embed = TokenEmbedding(
            vocab_size=self.vocab_size,
            d_model=self.d_model,
            padding_idx=self.padding_idx,
        )
        self.pos_embed = PositionEmbedding(
            max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
        )
        self.generator = nn.Linear(self.d_model, self.vocab_size)
        self.token_white_list = TOKEN_WHITE_LIST
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
        for b in self.layers:
            b.multihead_attn.k_cache = None
            b.multihead_attn.v_cache = None

        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        ):
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.self_attn.kv_cache = KVCache(
                max_batch_size,
                max_seq_length,
                self.config.n_head,
                head_dim,
                dtype,
                device,
            )
            b.multihead_attn.k_cache = None
            b.multihead_attn.v_cache = None

        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        ).to(device)

    def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
        input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
        tgt = tgt[:, -1:]
        tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
        # tgt = self.decoder(tgt_feature, memory, input_pos)
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):
            logits = tgt_feature
            tgt_mask = self.causal_mask[None, None, input_pos]
            for i, layer in enumerate(self.layers):
                logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
            # return output
            logits = self.generator(logits)[:, -1, :]
            total = set(range(logits.shape[-1]))
            black_list = list(total.difference(set(self.token_white_list)))
            logits[..., black_list] = -1e9
            probs = F.softmax(logits, dim=-1)
            _, next_tokens = probs.topk(1)
            return next_tokens


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.self_attn = Attention(config)
        self.multihead_attn = CrossAttention(config)

        layer_norm_eps = 1e-5

        d_model = config.dim
        dim_feedforward = config.intermediate_size

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm_first = config.norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.activation = _get_activation_fn(config.activation)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Tensor,
        input_pos: Tensor,
    ) -> Tensor:
        if self.norm_first:
            x = tgt
            x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos)
            x = x + self.multihead_attn(self.norm2(x), memory)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = tgt
            x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos))
            x = self.norm2(x + self.multihead_attn(x, memory))
            x = self.norm3(x + self._ff_block(x))
        return x

    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.activation(self.linear1(x)))
        return x


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, 3 * config.dim)
        self.wo = nn.Linear(config.dim, config.dim)

        self.kv_cache: Optional[KVCache] = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.dim = config.dim

    def forward(
        self,
        x: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_head * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y


class CrossAttention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        self.query = nn.Linear(config.dim, config.dim)
        self.key = nn.Linear(config.dim, config.dim)
        self.value = nn.Linear(config.dim, config.dim)
        self.out = nn.Linear(config.dim, config.dim)

        self.k_cache = None
        self.v_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim

    def get_kv(self, xa: torch.Tensor):
        if self.k_cache is not None and self.v_cache is not None:
            return self.k_cache, self.v_cache

        k = self.key(xa)
        v = self.value(xa)

        # Reshape to the expected layout
        batch_size, source_seq_len, _ = k.shape
        k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
        v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)

        if self.k_cache is None:
            self.k_cache = k
        if self.v_cache is None:
            self.v_cache = v

        return k, v

    def forward(
        self,
        x: Tensor,
        xa: Tensor,
    ):
        q = self.query(x)
        batch_size, target_seq_len, _ = q.shape
        q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
        k, v = self.get_kv(xa)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        wv = F.scaled_dot_product_attention(
            query=q,
            key=k,
            value=v,
            is_causal=False,
        )
        wv = wv.transpose(1, 2).reshape(
            batch_size,
            target_seq_len,
            self.n_head * self.head_dim,
        )

        return self.out(wv)
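Note: the decoder's forward() constrains greedy decoding with TOKEN_WHITE_LIST; a self-contained sketch of the same masking idea on a toy vocabulary (all values illustrative only):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 8)                                # toy vocab of 8 token ids
white_list = [1, 3, 4]                                    # allowed token ids
black_list = [i for i in range(8) if i not in white_list]
logits[..., black_list] = -1e9                            # mask disallowed ids
next_token = F.softmax(logits, dim=-1).topk(1).indices    # always in white_list
print(next_token)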
@ -0,0 +1,544 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os
import platform
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    get_available_providers,
    get_device,
)

from rapid_table.utils import Logger


class EP(Enum):
    CPU_EP = "CPUExecutionProvider"
    CUDA_EP = "CUDAExecutionProvider"
    DIRECTML_EP = "DmlExecutionProvider"


class OrtInferSession:
    def __init__(self, config: Dict[str, Any]):
        self.logger = Logger(logger_name=__name__).get_log()

        model_path = config.get("model_path", None)
        self._verify_model(model_path)

        self.cfg_use_cuda = config.get("use_cuda", None)
        self.cfg_use_dml = config.get("use_dml", None)

        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        self.session = InferenceSession(
            model_path,
            sess_options=sess_opt,
            providers=EP_list,
        )
        self._verify_providers()

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cpu_nums = os.cpu_count()
        intra_op_num_threads = config.get("intra_op_num_threads", -1)
        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = config.get("inter_op_num_threads", -1)
        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
            sess_opt.inter_op_num_threads = inter_op_num_threads

        return sess_opt

    def get_metadata(self, key: str = "character") -> list:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        content_list = meta_dict[key].splitlines()
        return content_list

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

        cuda_provider_opts = {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        }
        self.use_cuda = self._check_cuda()
        if self.use_cuda:
            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))

        self.use_directml = self._check_dml()
        if self.use_directml:
            self.logger.info(
                "Windows 10 or above detected, trying to use DirectML as the primary provider"
            )
            directml_options = (
                cuda_provider_opts if self.use_cuda else cpu_provider_opts
            )
            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
        return EP_list

    def _check_cuda(self) -> bool:
        if not self.cfg_use_cuda:
            return False

        cur_device = get_device()
        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.CUDA_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("!!!Recommend using rapidocr_paddle for inference on GPU.")
        self.logger.info(
            "(For reference only) If you want to use GPU acceleration, you must do:"
        )
        self.logger.info(
            "First, uninstall all onnxruntime packages in the current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
        )
        self.logger.info(
            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
        )
        self.logger.info(
            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
        )
        self.logger.info(
            "Third, ensure %s is in the available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
            EP.CUDA_EP.value,
        )
        return False

    def _check_dml(self) -> bool:
        if not self.cfg_use_dml:
            return False

        cur_os = platform.system()
        if cur_os != "Windows":
            self.logger.warning(
                "DirectML is only supported on Windows. The current OS is %s. Use %s inference by default.",
                cur_os,
                self.had_providers[0],
            )
            return False

        cur_window_version = int(platform.release().split(".")[0])
        if cur_window_version < 10:
            self.logger.warning(
                "DirectML is only supported on Windows 10 and above. The current Windows version is %s. Use %s inference by default.",
                cur_window_version,
                self.had_providers[0],
            )
            return False

        if EP.DIRECTML_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.DIRECTML_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("If you want to use DirectML acceleration, you must do:")
        self.logger.info(
            "First, uninstall all onnxruntime packages in the current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
        )
        self.logger.info(
            "Third, ensure %s is in the available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
            EP.DIRECTML_EP.value,
        )
        return False

    def _verify_providers(self):
        session_providers = self.session.get_providers()
        first_provider = session_providers[0]

        if self.use_cuda and first_provider != EP.CUDA_EP.value:
            self.logger.warning(
                "%s is not available in the current env; inference automatically falls back to %s.",
                EP.CUDA_EP.value,
                first_provider,
            )

        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
            self.logger.warning(
                "%s is not available in the current env; inference automatically falls back to %s.",
                EP.DIRECTML_EP.value,
                first_provider,
            )

    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        if key in meta_dict.keys():
            return True
        return False

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")


class ONNXRuntimeError(Exception):
    pass
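# --- Usage sketch (illustrative, not part of the commit) ---------------------
# A minimal sketch of driving OrtInferSession; the model path is hypothetical.
#
#   session = OrtInferSession({"model_path": "models/table_structure.onnx",
#                              "use_cuda": False})
#   outputs = session([input_blob])          # list of np.ndarray inputs
#   if session.have_key("character"):
#       characters = session.get_character_list()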
class TableLabelDecode:
    def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
        if merge_no_span_structure:
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")

        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character
        self.td_token = ["<td>", "<td", "<td></td>"]

    def __call__(self, preds, batch=None):
        structure_probs = preds["structure_probs"]
        bbox_preds = preds["loc_preds"]
        shape_list = batch[-1]
        result = self.decode(structure_probs, bbox_preds, shape_list)
        if len(batch) == 1:  # only contains shape
            return result

        label_decode_result = self.decode_label(batch)
        return result, label_decode_result

    def decode(self, structure_probs, bbox_preds, shape_list):
        """Convert structure probabilities back into HTML tokens and cell boxes."""
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        batch_size = len(structure_idx)
        for batch_idx in range(batch_size):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                if idx > 0 and char_idx == end_idx:
                    break

                if char_idx in ignored_tokens:
                    continue

                text = self.character[char_idx]
                if text in self.td_token:
                    bbox = bbox_preds[batch_idx, idx]
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            structure_batch_list.append([structure_list, np.mean(score_list)])
            bbox_batch_list.append(np.array(bbox_list))
        result = {
            "bbox_batch_list": bbox_batch_list,
            "structure_batch_list": structure_batch_list,
        }
        return result

    def decode_label(self, batch):
        """Convert ground-truth label indices back into tokens and cell boxes."""
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_batch_list = []
        bbox_batch_list = []
        batch_size = len(structure_idx)
        for batch_idx in range(batch_size):
            structure_list = []
            bbox_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                if idx > 0 and char_idx == end_idx:
                    break

                if char_idx in ignored_tokens:
                    continue

                structure_list.append(self.character[char_idx])

                bbox = gt_bbox_list[batch_idx][idx]
                if bbox.sum() != 0:
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)

            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)
        result = {
            "bbox_batch_list": bbox_batch_list,
            "structure_batch_list": structure_batch_list,
        }
        return result

    def _bbox_decode(self, bbox, shape):
        h, w = shape[:2]
        bbox[0::2] *= w
        bbox[1::2] *= h
        return bbox

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])

        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])

        raise TypeError(f"unsupported type {beg_or_end} in get_beg_end_flag_idx")

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character
class TablePreprocess:
    def __init__(self):
        self.table_max_len = 488
        self.build_pre_process_list()
        self.ops = self.create_operators()

    def __call__(self, data):
        """transform"""
        if self.ops is None:
            self.ops = []

        for op in self.ops:
            data = op(data)
            if data is None:
                return None
        return data

    def create_operators(
        self,
    ):
        """
        Create operators based on the config.

        Args:
            params (list): a list of dicts used to create the operators
        """
        assert isinstance(
            self.pre_process_list, list
        ), "operator config should be a list"
        ops = []
        for operator in self.pre_process_list:
            assert (
                isinstance(operator, dict) and len(operator) == 1
            ), "yaml format error"
            op_name = list(operator)[0]
            param = {} if operator[op_name] is None else operator[op_name]
            op = eval(op_name)(**param)
            ops.append(op)
        return ops

    def build_pre_process_list(self):
        resize_op = {
            "ResizeTableImage": {
                "max_len": self.table_max_len,
            }
        }
        pad_op = {
            "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
        }
        normalize_op = {
            "NormalizeImage": {
                "std": [0.229, 0.224, 0.225],
                "mean": [0.485, 0.456, 0.406],
                "scale": "1./255.",
                "order": "hwc",
            }
        }
        to_chw_op = {"ToCHWImage": None}
        keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
        self.pre_process_list = [
            resize_op,
            normalize_op,
            pad_op,
            to_chw_op,
            keep_keys_op,
        ]


class ResizeTableImage:
    def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
        super(ResizeTableImage, self).__init__()
        self.max_len = max_len
        self.resize_bboxes = resize_bboxes
        self.infer_mode = infer_mode

    def __call__(self, data):
        img = data["image"]
        height, width = img.shape[0:2]
        ratio = self.max_len / (max(height, width) * 1.0)
        resize_h = int(height * ratio)
        resize_w = int(width * ratio)
        resize_img = cv2.resize(img, (resize_w, resize_h))
        if self.resize_bboxes and not self.infer_mode:
            data["bboxes"] = data["bboxes"] * ratio
        data["image"] = resize_img
        data["src_img"] = img
        data["shape"] = np.array([height, width, ratio, ratio])
        data["max_len"] = self.max_len
        return data


class PaddingTableImage:
    def __init__(self, size, **kwargs):
        super(PaddingTableImage, self).__init__()
        self.size = size

    def __call__(self, data):
        img = data["image"]
        pad_h, pad_w = self.size
        padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
        height, width = img.shape[0:2]
        padding_img[0:height, 0:width, :] = img.copy()
        data["image"] = padding_img
        shape = data["shape"].tolist()
        shape.extend([pad_h, pad_w])
        data["shape"] = np.array(shape)
        return data


class NormalizeImage:
    """Normalize the image, e.g. subtract the mean and divide by the std."""

    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype("float32")
        self.std = np.array(std).reshape(shape).astype("float32")

    def __call__(self, data):
        img = np.array(data["image"])
        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
        return data


class ToCHWImage:
    """Convert an HWC image to CHW."""

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        img = np.array(data["image"])
        data["image"] = img.transpose((2, 0, 1))
        return data


class KeepKeys:
    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        data_list = []
        for key in self.keep_keys:
            data_list.append(data[key])
        return data_list


def trans_char_ocr_res(ocr_res):
    word_result = []
    for res in ocr_res:
        score = res[2]
        for word_box, word in zip(res[3], res[4]):
            word_res = []
            word_res.append(word_box)
            word_res.append(word)
            word_res.append(score)
            word_result.append(word_res)
    return word_result
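# --- Usage sketch (illustrative, not part of the commit) ---------------------
# The preprocess pipeline resizes the longer side to 488, normalizes with the
# ImageNet mean/std, pads to 488x488, transposes to CHW, and keeps only
# ["image", "shape"]; the input file name is hypothetical.
#
#   pre = TablePreprocess()
#   image, shape = pre({"image": cv2.imread("table.jpg")})
#   # image: (3, 488, 488) float32; shape: [h, w, ratio, ratio, 488, 488]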
@ -0,0 +1,8 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from .download_model import DownloadModel
from .load_image import LoadImage
from .logger import Logger
from .utils import is_url
from .vis import VisTable
@ -0,0 +1,67 @@
import io
from pathlib import Path
from typing import Optional, Union

import requests
from tqdm import tqdm

from .logger import Logger

PROJECT_DIR = Path(__file__).resolve().parent.parent
DEFAULT_MODEL_DIR = PROJECT_DIR / "models"


class DownloadModel:
    logger = Logger(logger_name=__name__).get_log()

    @classmethod
    def download(
        cls,
        model_full_url: Union[str, Path],
        save_dir: Union[str, Path, None] = None,
        save_model_name: Optional[str] = None,
    ) -> str:
        if save_dir is None:
            save_dir = DEFAULT_MODEL_DIR

        # Accept str paths as well; mkdir requires a Path
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        if save_model_name is None:
            save_model_name = Path(model_full_url).name

        save_file_path = save_dir / save_model_name
        if save_file_path.exists():
            cls.logger.info("%s already exists", save_file_path)
            return str(save_file_path)

        try:
            cls.logger.info("Download %s to %s", model_full_url, save_dir)
            file = cls.download_as_bytes_with_progress(model_full_url, save_model_name)
            cls.save_file(save_file_path, file)
        except Exception as exc:
            raise DownloadModelError from exc
        return str(save_file_path)

    @staticmethod
    def download_as_bytes_with_progress(
        url: Union[str, Path], name: Optional[str] = None
    ) -> bytes:
        resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180)
        total = int(resp.headers.get("content-length", 0))
        bio = io.BytesIO()
        with tqdm(
            desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024
        ) as pbar:
            for chunk in resp.iter_content(chunk_size=65536):
                pbar.update(len(chunk))
                bio.write(chunk)
        return bio.getvalue()

    @staticmethod
    def save_file(save_path: Union[str, Path], file: bytes):
        with open(save_path, "wb") as f:
            f.write(file)


class DownloadModelError(Exception):
    pass
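# --- Usage sketch (illustrative, not part of the commit) ---------------------
# Downloads are skipped when the target file already exists; the URL below is
# a placeholder, not a real model location.
#
#   path = DownloadModel.download(
#       "https://example.com/models/table_structure.onnx",  # hypothetical URL
#       save_dir="models",
#   )
#   print(path)   # models/table_structure.onnx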
@ -0,0 +1,131 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from io import BytesIO
from pathlib import Path
from typing import Any, Union

import cv2
import numpy as np
import requests
from PIL import Image, UnidentifiedImageError

from .utils import is_url

root_dir = Path(__file__).resolve().parent
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]


class LoadImage:
    def __init__(self):
        pass

    def __call__(self, img: InputType) -> np.ndarray:
        if not isinstance(img, InputType.__args__):
            raise LoadImageError(
                f"The img type {type(img)} is not in {InputType.__args__}"
            )

        origin_img_type = type(img)
        img = self.load_img(img)
        img = self.convert_img(img, origin_img_type)
        return img

    def load_img(self, img: InputType) -> np.ndarray:
        if isinstance(img, (str, Path)):
            if is_url(img):
                img = Image.open(requests.get(img, stream=True, timeout=60).raw)
            else:
                self.verify_exist(img)
                img = Image.open(img)

            try:
                img = self.img_to_ndarray(img)
            except UnidentifiedImageError as e:
                raise LoadImageError(f"cannot identify image file {img}") from e
            return img

        if isinstance(img, bytes):
            img = self.img_to_ndarray(Image.open(BytesIO(img)))
            return img

        if isinstance(img, np.ndarray):
            return img

        if isinstance(img, Image.Image):
            return self.img_to_ndarray(img)

        raise LoadImageError(f"{type(img)} is not supported!")

    def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
        # 1-bit images are converted to grayscale before the ndarray conversion
        if img.mode == "1":
            img = img.convert("L")
        return np.array(img)

    def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if img.ndim == 3:
            channel = img.shape[2]
            if channel == 1:
                return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            if channel == 2:
                return self.cvt_two_to_three(img)

            if channel == 3:
                if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
                    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                return img

            if channel == 4:
                return self.cvt_four_to_three(img)

            raise LoadImageError(
                f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
            )

        raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")

    @staticmethod
    def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
        """gray + alpha → BGR"""
        img_gray = img[..., 0]
        img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)

        img_alpha = img[..., 1]
        not_a = cv2.bitwise_not(img_alpha)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        # Keep visible pixels, composite transparent ones onto white
        new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
        new_img = cv2.add(new_img, not_a)
        return new_img

    @staticmethod
    def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
        """RGBA → BGR"""
        r, g, b, a = cv2.split(img)
        new_img = cv2.merge((b, g, r))

        not_a = cv2.bitwise_not(a)
        not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)

        new_img = cv2.bitwise_and(new_img, new_img, mask=a)

        mean_color = np.mean(new_img)
        if mean_color <= 0.0:
            new_img = cv2.add(new_img, not_a)
        else:
            new_img = cv2.bitwise_not(new_img)
        return new_img

    @staticmethod
    def verify_exist(file_path: Union[str, Path]):
        if not Path(file_path).exists():
            raise LoadImageError(f"{file_path} does not exist.")


class LoadImageError(Exception):
    pass
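# --- Usage sketch (illustrative, not part of the commit) ---------------------
# LoadImage accepts a path, URL, bytes, np.ndarray, or PIL image and always
# returns a 3-channel BGR np.ndarray; file names are hypothetical.
#
#   load = LoadImage()
#   img_a = load("table.jpg")
#   img_b = load(Path("table.png"))
#   img_c = load(open("table.jpg", "rb").read())   # raw bytes
#   assert img_a.ndim == 3 and img_a.shape[2] == 3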
@ -0,0 +1,37 @@
# -*- encoding: utf-8 -*-
# @Author: Jocker1212
# @Contact: xinyijianggo@gmail.com
import logging

import colorlog


class Logger:
    def __init__(self, log_level=logging.DEBUG, logger_name=None):
        self.logger = logging.getLogger(logger_name)
        self.logger.setLevel(log_level)
        self.logger.propagate = False

        formatter = colorlog.ColoredFormatter(
            "%(log_color)s[%(levelname)s] %(asctime)s [RapidTable] %(filename)s:%(lineno)d: %(message)s",
            log_colors={
                "DEBUG": "cyan",
                "INFO": "green",
                "WARNING": "yellow",
                "ERROR": "red",
                "CRITICAL": "red,bg_white",
            },
        )

        # Attach a console handler only once per logger name
        if not self.logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setFormatter(formatter)
            console_handler.setLevel(log_level)
            self.logger.addHandler(console_handler)

    def get_log(self):
        return self.logger
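# --- Usage sketch (illustrative, not part of the commit) ---------------------
#   logger = Logger(logger_name=__name__).get_log()
#   logger.info("model loaded")   # prints a colorized [INFO] ... [RapidTable] line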
@ -0,0 +1,12 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
from urllib.parse import urlparse


def is_url(url: str) -> bool:
    try:
        result = urlparse(url)
        return all([result.scheme, result.netloc])
    except Exception:
        return False
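# --- Usage sketch (illustrative, not part of the commit) ---------------------
#   is_url("https://example.com/model.onnx")   # True: has scheme and netloc
#   is_url("models/model.onnx")                # False: plain local path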
@ -0,0 +1,145 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os
from pathlib import Path
from typing import Optional, Union

import cv2
import numpy as np

from .load_image import LoadImage


class VisTable:
    def __init__(self):
        self.load_img = LoadImage()

    def __call__(
        self,
        img_path: Union[str, Path],
        table_results,
        save_html_path: Optional[str] = None,
        save_drawed_path: Optional[str] = None,
        save_logic_path: Optional[str] = None,
    ):
        if save_html_path:
            html_with_border = self.insert_border_style(table_results.pred_html)
            self.save_html(save_html_path, html_with_border)

        table_cell_bboxes = table_results.cell_bboxes
        if table_cell_bboxes is None:
            return None

        img = self.load_img(img_path)

        dims_bboxes = table_cell_bboxes.shape[1]
        if dims_bboxes == 4:
            drawed_img = self.draw_rectangle(img, table_cell_bboxes)
        elif dims_bboxes == 8:
            drawed_img = self.draw_polylines(img, table_cell_bboxes)
        else:
            raise ValueError("The dimension of the table bounding boxes must be 4 or 8.")

        if save_drawed_path:
            self.save_img(save_drawed_path, drawed_img)

        if save_logic_path and table_results.logic_points:
            polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
            self.plot_rec_box_with_logic_info(
                img, save_logic_path, table_results.logic_points, polygons
            )
        return drawed_img

    def insert_border_style(self, table_html_str: str):
        style_res = """<meta charset="UTF-8"><style>
        table {
            border-collapse: collapse;
            width: 100%;
        }
        th, td {
            border: 1px solid black;
            padding: 8px;
            text-align: center;
        }
        th {
            background-color: #f2f2f2;
        }
        </style>"""

        prefix_table, suffix_table = table_html_str.split("<body>")
        html_with_border = f"{prefix_table}{style_res}<body>{suffix_table}"
        return html_with_border

    def plot_rec_box_with_logic_info(
        self, img: np.ndarray, output_path, logic_points, sorted_polygons
    ):
        """
        :param img: source image
        :param output_path: path to save the visualization
        :param logic_points: [row_start, row_end, col_start, col_end]
        :param sorted_polygons: [xmin, ymin, xmax, ymax]
        :return:
        """
        # Pad the right side so labels near the border stay visible
        img = cv2.copyMakeBorder(
            img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
        )
        # Draw each polygon as a rectangle
        for idx, polygon in enumerate(sorted_polygons):
            x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
            x0 = round(x0)
            y0 = round(y0)
            x1 = round(x1)
            y1 = round(y1)
            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
            # Enlarged font size and line width (previously 0.5 and 1)
            font_scale = 0.9
            thickness = 1
            logic_point = logic_points[idx]
            cv2.putText(
                img,
                f"row: {logic_point[0]}-{logic_point[1]}",
                (x0 + 3, y0 + 8),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                (0, 0, 255),
                thickness,
            )
            cv2.putText(
                img,
                f"col: {logic_point[2]}-{logic_point[3]}",
                (x0 + 3, y0 + 18),
                cv2.FONT_HERSHEY_PLAIN,
                font_scale,
                (0, 0, 255),
                thickness,
            )
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # Save the rendered image
        self.save_img(output_path, img)

    @staticmethod
    def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
        img_copy = img.copy()
        for box in boxes.astype(int):
            x1, y1, x2, y2 = box
            cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def draw_polylines(img: np.ndarray, points) -> np.ndarray:
        img_copy = img.copy()
        for point in points.astype(int):
            point = point.reshape(4, 2)
            cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
        return img_copy

    @staticmethod
    def save_img(save_path: Union[str, Path], img: np.ndarray):
        cv2.imwrite(str(save_path), img)

    @staticmethod
    def save_html(save_path: Union[str, Path], html: str):
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(html)
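# --- Usage sketch (illustrative, not part of the commit) ---------------------
# `table_results` is expected to expose pred_html, cell_bboxes, and
# logic_points (as produced by the table recognizer); paths are hypothetical.
#
#   vis = VisTable()
#   vis(
#       "table.jpg",
#       table_results,
#       save_html_path="out/table.html",
#       save_drawed_path="out/table_cells.jpg",
#       save_logic_path="out/table_logic.jpg",
#   )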
@ -0,0 +1,18 @@
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re


# Processing pipeline
# Create a dataset instance from a single layout image
input_file = "/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/207.jpg"

ds = read_local_images(input_file)[0]

x = ds.apply(doc_analyze, ocr=True)
x = x.pipe_ocr_mode(None)
html = x.get_markdown(None)
content = md(html)
# Unescape the markdown control characters that markdownify escapes
content = re.sub(r'\\([#*_`])', r'\1', content)
print(content)
@ -0,0 +1,47 @@
from typing import List
from pdf2image import convert_from_path
import os
import paddleclas
import cv2
from .page_detection.utils import PageDetectionResult


paddle_clas_model = paddleclas.PaddleClas(model_name="text_image_orientation")


def pdf2image(pdf_path, output_dir):
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    images = convert_from_path(pdf_path)
    for i, image in enumerate(images):
        image.save(f'{output_dir}/{i + 1}.jpg')


def image_orient_cls(input_data):
    return paddle_clas_model.predict(input_data)


def page_detection_visual(page_detection_result: PageDetectionResult):
    img = cv2.imread(page_detection_result.image_path)
    for box in page_detection_result.boxes:
        pos = box.pos
        clsid = box.clsid
        confidence = box.confidence
        if clsid == 0:
            color = (0, 0, 0)
            text = 'text'
        elif clsid == 1:
            color = (255, 0, 0)
            text = 'title'
        elif clsid == 2:
            color = (0, 255, 0)
            text = 'figure'
        elif clsid == 4:
            color = (0, 0, 255)
            text = 'table'
        elif clsid == 5:
            color = (255, 0, 255)
            text = 'table caption'
        else:
            # Fallback so unmapped class ids (e.g. 3) still get drawn
            color = (128, 128, 128)
            text = f'class {clsid}'
        text = f'{text} {confidence}'
        img = cv2.rectangle(img, (int(pos[0]), int(pos[1])), (int(pos[2]), int(pos[3])), color, 2)
        cv2.putText(img, text, (int(pos[0]), int(pos[1])), cv2.FONT_HERSHEY_TRIPLEX, 1, color, 2)
    return img
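# --- Usage sketch (illustrative, not part of the commit) ---------------------
# Typical flow: rasterize a PDF, classify page orientation, then draw layout
# boxes once page detection has run; paths and the `result` object are
# hypothetical.
#
#   pdf2image("sample.pdf", "visual_images")
#   cls_result = image_orient_cls("visual_images/1.jpg")   # generator of predictions
#   # ... run page detection to get `result: PageDetectionResult` ...
#   cv2.imwrite("visual_images/1_boxes.jpg", page_detection_visual(result))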
@ -0,0 +1,262 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
from scipy.special import softmax
from scipy.interpolate import InterpolatedUnivariateSpline


def line_iou(pred, target, img_w, length=15, aligned=True):
    '''
    Calculate the line IoU between predictions and targets.

    Args:
        pred: lane predictions, shape: (num_pred, 72)
        target: ground truth, shape: (num_target, 72)
        img_w: image width
        length: extended radius
        aligned: True for IoU loss calculation, False for pair-wise IoUs in assignment
    '''
    px1 = pred - length
    px2 = pred + length
    tx1 = target - length
    tx2 = target + length

    if aligned:
        invalid_mask = target
        ovr = paddle.minimum(px2, tx2) - paddle.maximum(px1, tx1)
        union = paddle.maximum(px2, tx2) - paddle.minimum(px1, tx1)
    else:
        num_pred = pred.shape[0]
        invalid_mask = target.tile([num_pred, 1, 1])

        ovr = (paddle.minimum(px2[:, None, :], tx2[None, ...]) - paddle.maximum(
            px1[:, None, :], tx1[None, ...]))
        union = (paddle.maximum(px2[:, None, :], tx2[None, ...]) -
                 paddle.minimum(px1[:, None, :], tx1[None, ...]))

    invalid_masks = (invalid_mask < 0) | (invalid_mask >= img_w)

    ovr[invalid_masks] = 0.
    union[invalid_masks] = 0.
    iou = ovr.sum(axis=-1) / (union.sum(axis=-1) + 1e-9)
    return iou
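# --- Worked example (illustrative, not part of the commit) -------------------
# Each lane point is widened into a segment of radius `length`; the IoU is the
# summed overlap over the summed union. With length=1, identical lanes give
# ovr == union (IoU ~= 1), while a 2 px shift makes the segments [9, 11] and
# [11, 13] just touch, so the overlap is 0.
#
#   a = paddle.to_tensor([[10., 10., 10.]])
#   b = paddle.to_tensor([[12., 12., 12.]])
#   line_iou(a, a, img_w=100, length=1)   # ~1.0
#   line_iou(a, b, img_w=100, length=1)   # 0.0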
class Lane:
    def __init__(self, points=None, invalid_value=-2., metadata=None):
        super(Lane, self).__init__()
        self.curr_iter = 0
        self.points = points
        self.invalid_value = invalid_value
        self.function = InterpolatedUnivariateSpline(
            points[:, 1], points[:, 0], k=min(3, len(points) - 1))
        self.min_y = points[:, 1].min() - 0.01
        self.max_y = points[:, 1].max() + 0.01
        self.metadata = metadata or {}

    def __repr__(self):
        return '[Lane]\n' + str(self.points) + '\n[/Lane]'

    def __call__(self, lane_ys):
        lane_xs = self.function(lane_ys)
        # Mark samples outside the fitted y-range as invalid
        lane_xs[(lane_ys < self.min_y) | (lane_ys > self.max_y)] = self.invalid_value
        return lane_xs

    def to_array(self, sample_y_range, img_w, img_h):
        self.sample_y = range(sample_y_range[0], sample_y_range[1],
                              sample_y_range[2])
        sample_y = self.sample_y
        ys = np.array(sample_y) / float(img_h)
        xs = self(ys)
        valid_mask = (xs >= 0) & (xs < 1)
        lane_xs = xs[valid_mask] * img_w
        lane_ys = ys[valid_mask] * img_h
        lane = np.concatenate(
            (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), axis=1)
        return lane

    def __iter__(self):
        return self

    def __next__(self):
        if self.curr_iter < len(self.points):
            self.curr_iter += 1
            return self.points[self.curr_iter - 1]
        self.curr_iter = 0
        raise StopIteration
class CLRNetPostProcess(object):
    """
    Postprocess CLRNet network output into Lane objects.

    Args:
        img_w (int): network input image width
        ori_img_h (int): original image height before cropping
        cut_height (int): number of pixels cropped from the top of the image
        conf_threshold (float): confidence threshold for keeping a lane
        nms_thres (float): line-IoU threshold used by lane NMS
        max_lanes (int): maximum number of lanes kept per image
        num_points (int): number of sampled points per lane
    """

    def __init__(self, img_w, ori_img_h, cut_height, conf_threshold, nms_thres,
                 max_lanes, num_points):
        self.img_w = img_w
        self.conf_threshold = conf_threshold
        self.nms_thres = nms_thres
        self.max_lanes = max_lanes
        self.num_points = num_points
        self.n_strips = num_points - 1
        self.n_offsets = num_points
        self.ori_img_h = ori_img_h
        self.cut_height = cut_height

        self.prior_ys = paddle.linspace(
            start=1, stop=0, num=self.n_offsets).astype('float64')

    def predictions_to_pred(self, predictions):
        """
        Convert predictions to the internal Lane structure for evaluation.
        """
        lanes = []
        for lane in predictions:
            lane_xs = lane[6:].clone()
            start = min(
                max(0, int(round(lane[2].item() * self.n_strips))),
                self.n_strips)
            length = int(round(lane[5].item()))
            end = start + length - 1
            end = min(end, len(self.prior_ys) - 1)
            if start > 0:
                mask = ((lane_xs[:start] >= 0.) &
                        (lane_xs[:start] <= 1.)).cpu().detach().numpy()[::-1]
                mask = ~((mask.cumprod()[::-1]).astype(np.bool_))
                lane_xs[:start][mask] = -2
            if end < len(self.prior_ys) - 1:
                lane_xs[end + 1:] = -2

            lane_ys = self.prior_ys[lane_xs >= 0].clone()
            lane_xs = lane_xs[lane_xs >= 0]
            lane_xs = lane_xs.flip(axis=0).astype('float64')
            lane_ys = lane_ys.flip(axis=0)

            lane_ys = (lane_ys *
                       (self.ori_img_h - self.cut_height) + self.cut_height
                       ) / self.ori_img_h
            if len(lane_xs) <= 1:
                continue
            points = paddle.stack(
                x=(lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])),
                axis=1).squeeze(axis=2)
            lane = Lane(
                points=points.cpu().numpy(),
                metadata={
                    'start_x': lane[3],
                    'start_y': lane[2],
                    'conf': lane[1]
                })
            lanes.append(lane)
        return lanes

    def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k):
        """
        NMS for lane detection.
        predictions: paddle.Tensor [num_lanes, conf, y, x, length, 72 offsets], e.g. [12, 77]
        scores: paddle.Tensor [num_lanes]
        nms_overlap_thresh: float
        top_k: int
        """
        # sort by scores to get idx
        idx = scores.argsort(descending=True)
        keep = []

        candidates = predictions.clone()
        candidates = candidates.index_select(idx)

        while len(candidates) > 0:
            keep.append(idx[0])
            if len(keep) >= top_k or len(candidates) == 1:
                break

            ious = []
            for i in range(1, len(candidates)):
                ious.append(1 - line_iou(
                    candidates[i].unsqueeze(0),
                    candidates[0].unsqueeze(0),
                    img_w=self.img_w,
                    length=15))
            ious = paddle.to_tensor(ious)

            mask = ious <= nms_overlap_thresh
            kept_ids = paddle.where(mask == False)[0]

            if kept_ids.shape[0] == 0:
                break
            candidates = candidates[1:].index_select(kept_ids)
            idx = idx[1:].index_select(kept_ids)
        keep = paddle.stack(keep)

        return keep

    def get_lanes(self, output, as_lanes=True):
        """
        Convert model output to lanes.
        """
        softmax = nn.Softmax(axis=1)
        decoded = []

        for predictions in output:
            if len(predictions) == 0:
                decoded.append([])
                continue
            threshold = self.conf_threshold
            scores = softmax(predictions[:, :2])[:, 1]
            keep_inds = scores >= threshold
            predictions = predictions[keep_inds]
            scores = scores[keep_inds]

            if predictions.shape[0] == 0:
                decoded.append([])
                continue
            nms_predictions = predictions.detach().clone()
            nms_predictions = paddle.concat(
                x=[nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1)

            nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
            nms_predictions[..., 5:] = nms_predictions[..., 5:] * (
                self.img_w - 1)

            keep = self.lane_nms(
                nms_predictions[..., 5:],
                scores,
                nms_overlap_thresh=self.nms_thres,
                top_k=self.max_lanes)

            predictions = predictions.index_select(keep)

            if predictions.shape[0] == 0:
                decoded.append([])
                continue
            predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips)
            if as_lanes:
                pred = self.predictions_to_pred(predictions)
            else:
                pred = predictions
            decoded.append(pred)
        return decoded

    def __call__(self, lanes_list):
        lanes = self.get_lanes(lanes_list)
        return lanes
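# --- Usage sketch (illustrative, not part of the commit) ---------------------
# The values below mirror common CULane-style settings; they are assumptions,
# not the committed configuration.
#
#   post = CLRNetPostProcess(img_w=1640, ori_img_h=590, cut_height=270,
#                            conf_threshold=0.4, nms_thres=50,
#                            max_lanes=4, num_points=72)
#   lanes_per_image = post(model_output)   # list of [Lane, ...] per image
#   pts = lanes_per_image[0][0].to_array((270, 590, 10), img_w=1640, img_h=590)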
@ -0,0 +1,243 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/open-mmlab/mmpose/mmpose/core/post_processing/post_transforms.py
"""
import cv2
import numpy as np


class EvalAffine(object):
    def __init__(self, size, stride=64):
        super(EvalAffine, self).__init__()
        self.size = size
        self.stride = stride

    def __call__(self, image, im_info):
        s = self.size
        h, w, _ = image.shape
        trans, size_resized = get_affine_mat_kernel(h, w, s, inv=False)
        image_resized = cv2.warpAffine(image, trans, size_resized)
        return image_resized, im_info


def get_affine_mat_kernel(h, w, s, inv=False):
    if w < h:
        w_ = s
        h_ = int(np.ceil((s / w * h) / 64.) * 64)
        scale_w = w
        scale_h = h_ / w_ * w

    else:
        h_ = s
        w_ = int(np.ceil((s / h * w) / 64.) * 64)
        scale_h = h
        scale_w = w_ / h_ * h

    center = np.array([np.round(w / 2.), np.round(h / 2.)])

    size_resized = (w_, h_)
    trans = get_affine_transform(
        center, np.array([scale_w, scale_h]), 0, size_resized, inv=inv)

    return trans, size_resized


def get_affine_transform(center,
                         input_size,
                         rot,
                         output_size,
                         shift=(0., 0.),
                         inv=False):
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        input_size (np.ndarray[2, ]): Size of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (0-100%): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to invert the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(output_size) == 2
    assert len(shift) == 2
    if not isinstance(input_size, (np.ndarray, list)):
        input_size = np.array([input_size, input_size], dtype=np.float32)
    scale_tmp = input_size

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = rotate_point([0., src_w * -0.5], rot_rad)
    dst_dir = np.array([0., dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans
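# --- Usage sketch (illustrative, not part of the commit) ---------------------
# Warp a crop centered at (320, 240) with source size 200x200 into a 192x256
# patch; `img` is a hypothetical input image.
#
#   center = np.array([320., 240.])
#   trans = get_affine_transform(center, np.array([200., 200.]), 0, (192, 256))
#   patch = cv2.warpAffine(img, trans, (192, 256), flags=cv2.INTER_LINEAR)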
def get_warp_matrix(theta, size_input, size_dst, size_target):
    """This code is based on
    https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py

    Calculate the transformation matrix under the constraint of unbiased data
    processing.
    Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
    Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        theta (float): Rotation angle in degrees.
        size_input (np.ndarray): Size of input image [w, h].
        size_dst (np.ndarray): Size of output image [w, h].
        size_target (np.ndarray): Size of ROI in input plane [w, h].

    Returns:
        matrix (np.ndarray): A matrix for transformation.
    """
    theta = np.deg2rad(theta)
    matrix = np.zeros((2, 3), dtype=np.float32)
    scale_x = size_dst[0] / size_target[0]
    scale_y = size_dst[1] / size_target[1]
    matrix[0, 0] = np.cos(theta) * scale_x
    matrix[0, 1] = -np.sin(theta) * scale_x
    matrix[0, 2] = scale_x * (
        -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
        np.sin(theta) + 0.5 * size_target[0])
    matrix[1, 0] = np.sin(theta) * scale_y
    matrix[1, 1] = np.cos(theta) * scale_y
    matrix[1, 2] = scale_y * (
        -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
        np.cos(theta) + 0.5 * size_target[1])
    return matrix


def rotate_point(pt, angle_rad):
    """Rotate a point by an angle.

    Args:
        pt (list[float]): 2-dimensional point to be rotated
        angle_rad (float): rotation angle in radians

    Returns:
        list[float]: Rotated point.
    """
    assert len(pt) == 2
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    new_x = pt[0] * cs - pt[1] * sn
    new_y = pt[0] * sn + pt[1] * cs
    rotated_pt = [new_x, new_y]

    return rotated_pt


def _get_3rd_point(a, b):
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating the vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): point (x, y)
        b (np.ndarray): point (x, y)

    Returns:
        np.ndarray: The 3rd point.
    """
    assert len(a) == 2
    assert len(b) == 2
    direction = a - b
    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)

    return third_pt


class TopDownEvalAffine(object):
    """Apply an affine transform to the image and coords.

    Args:
        trainsize (list): [w, h], the standard size used for training
        use_udp (bool): whether to use Unbiased Data Processing.
        records (dict): the dict containing the image and coords

    Returns:
        records (dict): contains the image and coords after the transform
    """

    def __init__(self, trainsize, use_udp=False):
        self.trainsize = trainsize
        self.use_udp = use_udp

    def __call__(self, image, im_info):
        rot = 0
        imshape = im_info['im_shape'][::-1]
        center = im_info['center'] if 'center' in im_info else imshape / 2.
        scale = im_info['scale'] if 'scale' in im_info else imshape
        if self.use_udp:
            trans = get_warp_matrix(
                rot, center * 2.0,
                [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)
        else:
            trans = get_affine_transform(center, scale, rot, self.trainsize)
            image = cv2.warpAffine(
                image,
                trans, (int(self.trainsize[0]), int(self.trainsize[1])),
                flags=cv2.INTER_LINEAR)

        return image, im_info


def expand_crop(images, rect, expand_ratio=0.3):
    imgh, imgw, c = images.shape
    label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()]
    if label != 0:
        return None, None, None
    org_rect = [xmin, ymin, xmax, ymax]
    h_half = (ymax - ymin) * (1 + expand_ratio) / 2.
    w_half = (xmax - xmin) * (1 + expand_ratio) / 2.
    if h_half > w_half * 4 / 3:
        w_half = h_half * 0.75
    center = [(ymin + ymax) / 2., (xmin + xmax) / 2.]
    ymin = max(0, int(center[0] - h_half))
    ymax = min(imgh - 1, int(center[0] + h_half))
    xmin = max(0, int(center[1] - w_half))
    xmax = min(imgw - 1, int(center[1] + w_half))
    return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax], org_rect
@ -0,0 +1,918 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import glob
import json
from pathlib import Path

import cv2
import numpy as np
import math
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor

import sys
# add the deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, '..'))
sys.path.insert(0, parent_path)

from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
from picodet_postprocess import PicoDetPostProcess
from clrnet_postprocess import CLRNetPostProcess
from visualize import visualize_box_mask, imshow_lanes
from utils import argsparser, Timer, multiclass_nms, coco_clsid2catid

# Global dictionary
SUPPORT_MODELS = {
    'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
    'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
    'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet',
    'PPLCNet', 'DETR', 'CenterTrack', 'CLRNet'
}


class Detector(object):
    """
    Args:
        pred_config (object): config of the model, defined by `Config(model_dir)`
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size used in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode needs to be set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to enable mkldnn bfloat16
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
            Used by the action model.
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 enable_mkldnn_bfloat16=False,
                 output_dir='output',
                 threshold=0.5,
                 delete_shuffle_pass=False,
                 use_fd_format=False):
        self.pred_config = self.set_config(
            model_dir, use_fd_format=use_fd_format)
        self.predictor, self.config = load_predictor(
            model_dir,
            self.pred_config.arch,
            run_mode=run_mode,
            batch_size=batch_size,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            device=device,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            delete_shuffle_pass=delete_shuffle_pass)
        self.det_times = Timer()
        self.batch_size = batch_size
        self.output_dir = output_dir
        self.threshold = threshold
        self.device = device
    def set_config(self, model_dir, use_fd_format):
        return PredictConfig(model_dir, use_fd_format=use_fd_format)

    def preprocess(self, image_list):
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            preprocess_ops.append(eval(op_type)(**new_op_info))

        input_im_lst = []
        input_im_info_lst = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            input_im_lst.append(im)
            input_im_info_lst.append(im_info)
        inputs = create_inputs(input_im_lst, input_im_info_lst)
        input_names = self.predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.predictor.get_input_handle(input_names[i])
            if input_names[i] == 'x':
                input_tensor.copy_from_cpu(inputs['image'])
            else:
                input_tensor.copy_from_cpu(inputs[input_names[i]])

        return inputs

    def postprocess(self, inputs, result):
        # postprocess output of the predictor
        np_boxes_num = result['boxes_num']
        assert isinstance(np_boxes_num, np.ndarray), \
            '`np_boxes_num` should be a `numpy.ndarray`'

        result = {k: v for k, v in result.items() if v is not None}
        return result

    def filter_box(self, result, threshold):
        np_boxes_num = result['boxes_num']
        boxes = result['boxes']
        start_idx = 0
        filter_boxes = []
        filter_num = []
        for i in range(len(np_boxes_num)):
            boxes_num = np_boxes_num[i]
            boxes_i = boxes[start_idx:start_idx + boxes_num, :]
            idx = boxes_i[:, 1] > threshold
            filter_boxes_i = boxes_i[idx, :]
            filter_boxes.append(filter_boxes_i)
            filter_num.append(filter_boxes_i.shape[0])
            start_idx += boxes_num
        boxes = np.concatenate(filter_boxes)
        filter_num = np.array(filter_num)
        filter_res = {'boxes': boxes, 'boxes_num': filter_num}
        return filter_res

    def predict(self, repeats=1, run_benchmark=False):
        '''
        Args:
            repeats (int): number of times to repeat the prediction
        Returns:
            result (dict): includes 'boxes': np.ndarray: shape:[N,6], N: number of boxes,
                matrix element: [class, score, x_min, y_min, x_max, y_max]
            MaskRCNN's result includes 'masks': np.ndarray:
                shape: [N, im_h, im_w]
        '''
        # model prediction
        np_boxes_num, np_boxes, np_masks = np.array([0]), None, None

        if run_benchmark:
            for i in range(repeats):
                self.predictor.run()
                if self.device == 'GPU':
                    paddle.device.cuda.synchronize()
                else:
                    paddle.device.synchronize(device=self.device.lower())

            result = dict(
                boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
            return result

        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            boxes_tensor = self.predictor.get_output_handle(output_names[0])
            np_boxes = boxes_tensor.copy_to_cpu()
            if len(output_names) == 1:
                # some exported models cannot provide the tensor 'bbox_num'
                np_boxes_num = np.array([len(np_boxes)])
            else:
                boxes_num = self.predictor.get_output_handle(output_names[1])
                np_boxes_num = boxes_num.copy_to_cpu()
            if self.pred_config.mask:
                masks_tensor = self.predictor.get_output_handle(output_names[2])
                np_masks = masks_tensor.copy_to_cpu()
        result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
        return result

    def merge_batch_result(self, batch_result):
        if len(batch_result) == 1:
            return batch_result[0]
        res_key = batch_result[0].keys()
        results = {k: [] for k in res_key}
        for res in batch_result:
            for k, v in res.items():
                results[k].append(v)
        for k, v in results.items():
            if k not in ['masks', 'segm']:
                results[k] = np.concatenate(v)
        return results

    def get_timer(self):
        return self.det_times
def predict_image(self,
|
||||
image_list,
|
||||
threshold=0.5,
|
||||
visual=True):
|
||||
batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
|
||||
results = []
|
||||
for i in range(batch_loop_cnt):
|
||||
start_index = i * self.batch_size
|
||||
end_index = min((i + 1) * self.batch_size, len(image_list))
|
||||
batch_image_list = image_list[start_index:end_index]
|
||||
# preprocess
|
||||
self.det_times.preprocess_time_s.start()
|
||||
inputs = self.preprocess(batch_image_list)
|
||||
self.det_times.preprocess_time_s.end()
|
||||
|
||||
# model prediction
|
||||
self.det_times.inference_time_s.start()
|
||||
result = self.predict()
|
||||
self.det_times.inference_time_s.end()
|
||||
|
||||
# postprocess
|
||||
self.det_times.postprocess_time_s.start()
|
||||
result = self.postprocess(inputs, result)
|
||||
self.det_times.postprocess_time_s.end()
|
||||
self.det_times.img_num += len(batch_image_list)
|
||||
|
||||
if visual:
|
||||
visualize(
|
||||
batch_image_list,
|
||||
result,
|
||||
self.pred_config.labels,
|
||||
output_dir=self.output_dir,
|
||||
threshold=self.threshold)
|
||||
# TODO 在这里处理batch
|
||||
results.append(result)
|
||||
results = self.merge_batch_result(results)
|
||||
boxes = results['boxes']
|
||||
expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
|
||||
boxes = boxes[expect_boxes, :]
|
||||
output = []
|
||||
for dt in boxes:
|
||||
clsid, box, confidence = int(dt[0]), dt[2:].tolist(), dt[1]
|
||||
output.append((clsid, box, confidence))
|
||||
return output
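

# Usage sketch (illustrative only): how the (clsid, box, confidence) tuples
# returned by predict_image are typically consumed. The model directory and
# image path below are hypothetical placeholders, not files from this repo.
def _demo_predict_image():
    detector = Detector('./layout_model_dir', device='CPU')  # hypothetical dir
    for clsid, box, confidence in detector.predict_image(
            ['./page_0.jpg'], threshold=0.5, visual=False):
        x_min, y_min, x_max, y_max = box
        print('class=%d score=%.3f box=(%.1f, %.1f, %.1f, %.1f)' %
              (clsid, confidence, x_min, y_min, x_max, y_max))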


class DetectorSOLOv2(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): size of per batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode needs to be set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to turn on MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN bfloat16
        output_dir (str): The path of output
        threshold (float): The threshold of score for visualization
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 enable_mkldnn_bfloat16=False,
                 output_dir='./',
                 threshold=0.5,
                 use_fd_format=False):
        super(DetectorSOLOv2, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold,
            use_fd_format=use_fd_format)

    def predict(self, repeats=1, run_benchmark=False):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): 'segm': np.ndarray, shape:[N, im_h, im_w]
                'cate_label': label of segm, shape:[N]
                'cate_score': confidence score of segm, shape:[N]
        '''
        np_segms, np_label, np_score, np_boxes_num = None, None, None, np.array(
            [0])

        if run_benchmark:
            for i in range(repeats):
                self.predictor.run()
                paddle.device.cuda.synchronize()
            result = dict(
                segm=np_segms,
                label=np_label,
                score=np_score,
                boxes_num=np_boxes_num)
            return result

        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            np_segms = self.predictor.get_output_handle(output_names[
                0]).copy_to_cpu()
            np_boxes_num = self.predictor.get_output_handle(output_names[
                1]).copy_to_cpu()
            np_label = self.predictor.get_output_handle(output_names[
                2]).copy_to_cpu()
            np_score = self.predictor.get_output_handle(output_names[
                3]).copy_to_cpu()

        result = dict(
            segm=np_segms,
            label=np_label,
            score=np_score,
            boxes_num=np_boxes_num)
        return result


class DetectorPicoDet(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): size of per batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode needs to be set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to turn on MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN bfloat16
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 enable_mkldnn_bfloat16=False,
                 output_dir='./',
                 threshold=0.5,
                 use_fd_format=False):
        super(DetectorPicoDet, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold,
            use_fd_format=use_fd_format)

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        np_score_list = result['boxes']
        np_boxes_list = result['boxes_num']
        postprocessor = PicoDetPostProcess(
            inputs['image'].shape[2:],
            inputs['im_shape'],
            inputs['scale_factor'],
            strides=self.pred_config.fpn_stride,
            nms_threshold=self.pred_config.nms['nms_threshold'])
        np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
        result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
        return result

    def predict(self, repeats=1, run_benchmark=False):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of boxes,
                matrix element: [class, score, x_min, y_min, x_max, y_max]
        '''
        np_score_list, np_boxes_list = [], []

        if run_benchmark:
            for i in range(repeats):
                self.predictor.run()
                paddle.device.cuda.synchronize()
            result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
            return result

        for i in range(repeats):
            self.predictor.run()
            np_score_list.clear()
            np_boxes_list.clear()
            output_names = self.predictor.get_output_names()
            num_outs = int(len(output_names) / 2)
            for out_idx in range(num_outs):
                np_score_list.append(
                    self.predictor.get_output_handle(output_names[out_idx])
                    .copy_to_cpu())
                np_boxes_list.append(
                    self.predictor.get_output_handle(output_names[
                        out_idx + num_outs]).copy_to_cpu())
        result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
        return result


class DetectorCLRNet(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): size of per batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode needs to be set True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to turn on MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN bfloat16
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 enable_mkldnn_bfloat16=False,
                 output_dir='./',
                 threshold=0.5,
                 use_fd_format=False):
        super(DetectorCLRNet, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold,
            use_fd_format=use_fd_format)

        deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        self.img_w = yml_conf['img_w']
        self.ori_img_h = yml_conf['ori_img_h']
        self.cut_height = yml_conf['cut_height']
        self.max_lanes = yml_conf['max_lanes']
        self.nms_thres = yml_conf['nms_thres']
        self.num_points = yml_conf['num_points']
        self.conf_threshold = yml_conf['conf_threshold']

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        lanes_list = result['lanes']
        postprocessor = CLRNetPostProcess(
            img_w=self.img_w,
            ori_img_h=self.ori_img_h,
            cut_height=self.cut_height,
            conf_threshold=self.conf_threshold,
            nms_thres=self.nms_thres,
            max_lanes=self.max_lanes,
            num_points=self.num_points)
        lanes = postprocessor(lanes_list)
        result = dict(lanes=lanes)
        return result

    def predict(self, repeats=1, run_benchmark=False):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): 'lanes': list of raw lane outputs, decoded by postprocess
        '''
        lanes_list = []

        if run_benchmark:
            for i in range(repeats):
                self.predictor.run()
                paddle.device.cuda.synchronize()
            result = dict(lanes=lanes_list)
            return result

        for i in range(repeats):
            # TODO: check the output of predictor
            self.predictor.run()
            lanes_list.clear()
            output_names = self.predictor.get_output_names()
            num_outs = int(len(output_names) / 2)
            if num_outs == 0:
                lanes_list.append([])
            for out_idx in range(num_outs):
                lanes_list.append(
                    self.predictor.get_output_handle(output_names[out_idx])
                    .copy_to_cpu())
        result = dict(lanes=lanes_list)
        return result


def create_inputs(imgs, im_info):
    """generate input for different model type
    Args:
        imgs (list(numpy)): list of images (np.ndarray)
        im_info (list(dict)): list of image info
    Returns:
        inputs (dict): input of model
    """
    inputs = {}

    im_shape = []
    scale_factor = []
    if len(imgs) == 1:
        inputs['image'] = np.array((imgs[0], )).astype('float32')
        inputs['im_shape'] = np.array(
            (im_info[0]['im_shape'], )).astype('float32')
        inputs['scale_factor'] = np.array(
            (im_info[0]['scale_factor'], )).astype('float32')
        return inputs

    for e in im_info:
        im_shape.append(np.array((e['im_shape'], )).astype('float32'))
        scale_factor.append(np.array((e['scale_factor'], )).astype('float32'))

    inputs['im_shape'] = np.concatenate(im_shape, axis=0)
    inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

    imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
    max_shape_h = max([e[0] for e in imgs_shape])
    max_shape_w = max([e[1] for e in imgs_shape])
    padding_imgs = []
    for img in imgs:
        im_c, im_h, im_w = img.shape[:]
        padding_im = np.zeros(
            (im_c, max_shape_h, max_shape_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = img
        padding_imgs.append(padding_im)
    inputs['image'] = np.stack(padding_imgs, axis=0)
    return inputs
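

# Sketch of the batched branch above: two mock CHW float32 "images" of
# different sizes get zero-padded to a common height/width and stacked.
# All shapes and values here are made up for illustration.
def _demo_create_inputs():
    imgs = [np.ones((3, 320, 480), np.float32),
            np.ones((3, 480, 320), np.float32)]
    im_info = [{'im_shape': np.array([320., 480.], np.float32),
                'scale_factor': np.array([1., 1.], np.float32)},
               {'im_shape': np.array([480., 320.], np.float32),
                'scale_factor': np.array([1., 1.], np.float32)}]
    inputs = create_inputs(imgs, im_info)
    print(inputs['image'].shape)  # (2, 3, 480, 480): padded to max H and max W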


class PredictConfig():
    """set config of preprocess, postprocess and visualize
    Args:
        model_dir (str): root path of infer_cfg.yml (or inference.yml for FD format)
    """

    def __init__(self, model_dir, use_fd_format=False):
        # parsing Yaml config for Preprocess
        fd_deploy_file = os.path.join(model_dir, 'inference.yml')
        ppdet_deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        if use_fd_format:
            if not os.path.exists(fd_deploy_file) and os.path.exists(
                    ppdet_deploy_file):
                raise RuntimeError(
                    "Non-FD format model detected. Please set `use_fd_format` to False."
                )
            deploy_file = fd_deploy_file
        else:
            if not os.path.exists(ppdet_deploy_file) and os.path.exists(
                    fd_deploy_file):
                raise RuntimeError(
                    "FD format model detected. Please set `use_fd_format` to True."
                )
            deploy_file = ppdet_deploy_file
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']
        self.mask = False
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        if 'mask' in yml_conf:
            self.mask = yml_conf['mask']
        self.tracker = None
        if 'tracker' in yml_conf:
            self.tracker = yml_conf['tracker']
        if 'NMS' in yml_conf:
            self.nms = yml_conf['NMS']
        if 'fpn_stride' in yml_conf:
            self.fpn_stride = yml_conf['fpn_stride']
        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
            print(
                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
            )
        self.print_config()

    def check_model(self, yml_conf):
        """
        Raises:
            ValueError: loaded model not in supported model type
        """
        for support_model in SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
            'arch'], SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')


def load_predictor(model_dir,
                   arch,
                   run_mode='paddle',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   use_dynamic_shape=False,
                   trt_min_shape=1,
                   trt_max_shape=1280,
                   trt_opt_shape=640,
                   trt_calib_mode=False,
                   cpu_threads=1,
                   enable_mkldnn=False,
                   enable_mkldnn_bfloat16=False,
                   delete_shuffle_pass=False):
    """set AnalysisConfig, generate AnalysisPredictor
    Args:
        model_dir (str): root path of __model__ and __params__
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16/trt_int8)
        use_dynamic_shape (bool): use dynamic shape or not
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode needs to be set True
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
            Used by the action model.
    Returns:
        predictor (PaddlePredictor): AnalysisPredictor
    Raises:
        ValueError: predicting with TensorRT requires device == 'GPU'.
    """
    if device != 'GPU' and run_mode != 'paddle':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
            .format(run_mode, device))

    if paddle.__version__ >= '3.0.0' or paddle.__version__ == '0.0.0':
        model_path = model_dir
        model_prefix = 'model'
        infer_param = os.path.join(model_dir, 'model.pdiparams')
        if not os.path.exists(infer_param):
            if paddle.framework.use_pir_api():
                infer_model = os.path.join(model_dir, 'inference.pdmodel')
            else:
                infer_model = os.path.join(model_dir, 'inference.json')
            if not os.path.exists(infer_model):
                raise ValueError(
                    "Cannot find any inference model in dir: {}.".format(model_dir))

        config = Config(model_path, model_prefix)

    else:
        infer_model = os.path.join(model_dir, 'model.pdmodel')
        infer_params = os.path.join(model_dir, 'model.pdiparams')
        if not os.path.exists(infer_model):
            infer_model = os.path.join(model_dir, 'inference.pdmodel')
            infer_params = os.path.join(model_dir, 'inference.pdiparams')
            if not os.path.exists(infer_model):
                raise ValueError(
                    "Cannot find any inference model in dir: {}.".format(model_dir))
        config = Config(infer_model, infer_params)

    if device == 'GPU':
        # initial GPU memory(M), device ID
        config.enable_use_gpu(200, 0)
        # optimize graph and fuse op
        config.switch_ir_optim(True)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(cpu_threads)
        if enable_mkldnn:
            try:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if enable_mkldnn_bfloat16:
                    config.enable_mkldnn_bfloat16()
            except Exception as e:
                print(
                    "The current environment does not support `mkldnn`, so disable mkldnn."
                )
                pass

    precision_map = {
        'trt_int8': Config.Precision.Int8,
        'trt_fp32': Config.Precision.Float32,
        'trt_fp16': Config.Precision.Half
    }
    if run_mode in precision_map.keys():
        config.enable_tensorrt_engine(
            workspace_size=(1 << 25) * batch_size,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)
        if FLAGS.collect_trt_shape_info:
            config.collect_shape_range_info(FLAGS.tuned_trt_shape_file)
        elif os.path.exists(FLAGS.tuned_trt_shape_file):
            print(f'Use dynamic shape file: '
                  f'{FLAGS.tuned_trt_shape_file} for TRT...')
            config.enable_tuned_tensorrt_dynamic_shape(
                FLAGS.tuned_trt_shape_file, True)

        if use_dynamic_shape:
            min_input_shape = {
                'image': [batch_size, 3, trt_min_shape, trt_min_shape],
                'scale_factor': [batch_size, 2]
            }
            max_input_shape = {
                'image': [batch_size, 3, trt_max_shape, trt_max_shape],
                'scale_factor': [batch_size, 2]
            }
            opt_input_shape = {
                'image': [batch_size, 3, trt_opt_shape, trt_opt_shape],
                'scale_factor': [batch_size, 2]
            }
            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                              opt_input_shape)
            print('trt set dynamic shape done!')

    # disable print log when predict
    config.disable_glog_info()
    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    if delete_shuffle_pass:
        config.delete_pass("shuffle_channel_detect_pass")
    predictor = create_predictor(config)
    return predictor, config
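

# Call sketch (hypothetical paths): a minimal CPU invocation of load_predictor.
# With run_mode='paddle' the TensorRT branch is skipped, so none of the FLAGS
# fields are touched.
def _demo_load_predictor():
    predictor, config = load_predictor(
        './layout_model_dir',  # hypothetical exported-model directory
        arch='PicoDet',
        run_mode='paddle',
        batch_size=1,
        device='CPU')
    print(predictor.get_input_names())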


def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
    # visualize the predict result
    if 'lanes' in result:
        for idx, image_file in enumerate(image_list):
            lanes = result['lanes'][idx]
            img = cv2.imread(image_file)
            out_file = os.path.join(output_dir, os.path.basename(image_file))
            # hard code
            lanes = [lane.to_array([], ) for lane in lanes]
            imshow_lanes(img, lanes, out_file=out_file)
        return
    start_idx = 0
    for idx, image_file in enumerate(image_list):
        im_bboxes_num = result['boxes_num'][idx]
        im_results = {}
        if 'boxes' in result:
            im_results['boxes'] = result['boxes'][start_idx:start_idx +
                                                  im_bboxes_num, :]
        if 'masks' in result:
            im_results['masks'] = result['masks'][start_idx:start_idx +
                                                  im_bboxes_num, :]
        if 'segm' in result:
            im_results['segm'] = result['segm'][start_idx:start_idx +
                                                im_bboxes_num, :]
        if 'label' in result:
            im_results['label'] = result['label'][start_idx:start_idx +
                                                  im_bboxes_num]
        if 'score' in result:
            im_results['score'] = result['score'][start_idx:start_idx +
                                                  im_bboxes_num]

        start_idx += im_bboxes_num
        im = visualize_box_mask(
            image_file, im_results, labels, threshold=threshold)
        img_name = os.path.split(image_file)[-1]
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        out_path = os.path.join(output_dir, img_name)
        im.save(out_path, quality=95)
        print("save result to: " + out_path)


def print_arguments(args):
    print('----------- Running Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------')


class Pipeline(object):

    def __init__(self, model_dir):
        if FLAGS.use_fd_format:
            deploy_file = os.path.join(model_dir, 'inference.yml')
        else:
            deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        arch = yml_conf['arch']
        detector_func = 'Detector'
        if arch == 'SOLOv2':
            detector_func = 'DetectorSOLOv2'
        elif arch == 'PicoDet':
            detector_func = 'DetectorPicoDet'
        elif arch == "CLRNet":
            detector_func = 'DetectorCLRNet'

        self.detector = eval(detector_func)(
            model_dir,
            device=FLAGS.device,
            run_mode=FLAGS.run_mode,
            batch_size=FLAGS.batch_size,
            trt_min_shape=FLAGS.trt_min_shape,
            trt_max_shape=FLAGS.trt_max_shape,
            trt_opt_shape=FLAGS.trt_opt_shape,
            trt_calib_mode=FLAGS.trt_calib_mode,
            cpu_threads=FLAGS.cpu_threads,
            enable_mkldnn=FLAGS.enable_mkldnn,
            enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
            threshold=FLAGS.threshold,
            output_dir=FLAGS.output_dir,
            use_fd_format=FLAGS.use_fd_format)

    def __call__(self, image_path):
        if FLAGS.image_dir is None and image_path is not None:
            assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
        if isinstance(image_path, str):
            image_path = [image_path]
        results = self.detector.predict_image(
            image_path, visual=FLAGS.save_images)
        return results


paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
print_arguments(FLAGS)
FLAGS.device = 'GPU'
FLAGS.save_images = False
FLAGS.device = FLAGS.device.upper()
assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU', 'MLU', 'GCU'
                        ], "device should be CPU, GPU, XPU, NPU, MLU or GCU"
assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"

assert not (
    FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True
), 'To enable mkldnn bfloat16, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'
@ -0,0 +1,227 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from scipy.special import softmax


def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
    """
    Args:
        box_scores (N, 5): boxes in corner-form and probabilities.
        iou_threshold: intersection over union threshold.
        top_k: keep top_k results. If k <= 0, keep all the results.
        candidate_size: only consider the candidates with the highest scores.
    Returns:
        the kept rows of box_scores, ordered by descending score.
    """
    scores = box_scores[:, -1]
    boxes = box_scores[:, :-1]
    picked = []
    indexes = np.argsort(scores)
    indexes = indexes[-candidate_size:]
    while len(indexes) > 0:
        current = indexes[-1]
        picked.append(current)
        if 0 < top_k == len(picked) or len(indexes) == 1:
            break
        current_box = boxes[current, :]
        indexes = indexes[:-1]
        rest_boxes = boxes[indexes, :]
        iou = iou_of(
            rest_boxes,
            np.expand_dims(
                current_box, axis=0), )
        indexes = indexes[iou <= iou_threshold]

    return box_scores[picked, :]
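

# Worked example: two heavily overlapping boxes and one separate box. With an
# IoU threshold of 0.5, hard_nms keeps the higher-scoring member of the
# overlapping pair plus the separate box; the values are illustrative only.
def _demo_hard_nms():
    box_scores = np.array([
        [0., 0., 10., 10., 0.9],
        [1., 1., 11., 11., 0.8],  # IoU with the first box ~0.68 -> suppressed
        [50., 50., 60., 60., 0.7],
    ], np.float32)
    kept = hard_nms(box_scores, iou_threshold=0.5)
    print(kept[:, 4])  # [0.9 0.7], highest score first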


def iou_of(boxes0, boxes1, eps=1e-5):
    """Return intersection-over-union (Jaccard index) of boxes.
    Args:
        boxes0 (N, 4): ground truth boxes.
        boxes1 (N or 1, 4): predicted boxes.
        eps: a small number to avoid 0 as denominator.
    Returns:
        iou (N): IoU values.
    """
    overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
    overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])

    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
    return overlap_area / (area0 + area1 - overlap_area + eps)


def area_of(left_top, right_bottom):
    """Compute the areas of rectangles given two corners.
    Args:
        left_top (N, 2): left top corner.
        right_bottom (N, 2): right bottom corner.
    Returns:
        area (N): return the area.
    """
    hw = np.clip(right_bottom - left_top, 0.0, None)
    return hw[..., 0] * hw[..., 1]
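

# Quick numeric check of iou_of: a 10x10 box against a copy shifted by 5 in
# both directions overlaps by 25, so IoU = 25 / (100 + 100 - 25) ~= 0.143.
def _demo_iou_of():
    a = np.array([[0., 0., 10., 10.]])
    b = np.array([[5., 5., 15., 15.]])
    print(iou_of(a, b))  # ~[0.1429]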


class PicoDetPostProcess(object):
    """
    Args:
        input_shape (int): network input image size
        ori_shape (int): ori image shape before padding
        scale_factor (float): scale factor of ori image
        strides (list): strides of the FPN feature maps
        score_threshold (float): score threshold for filtering boxes
        nms_threshold (float): IoU threshold used in NMS
        nms_top_k (int): number of candidates kept before NMS
        keep_top_k (int): number of boxes kept after NMS
    """

    def __init__(self,
                 input_shape,
                 ori_shape,
                 scale_factor,
                 strides=[8, 16, 32, 64],
                 score_threshold=0.4,
                 nms_threshold=0.5,
                 nms_top_k=1000,
                 keep_top_k=100):
        self.ori_shape = ori_shape
        self.input_shape = input_shape
        self.scale_factor = scale_factor
        self.strides = strides
        self.score_threshold = score_threshold
        self.nms_threshold = nms_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k

    def warp_boxes(self, boxes, ori_shape):
        """Apply transform to boxes."""
        width, height = ori_shape[1], ori_shape[0]
        n = len(boxes)
        if n:
            # warp points
            xy = np.ones((n * 4, 3))
            xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
                n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
            # xy = xy @ M.T  # transform
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
            # create new boxes
            x = xy[:, [0, 2, 4, 6]]
            y = xy[:, [1, 3, 5, 7]]
            xy = np.concatenate(
                (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
            # clip boxes
            xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
            xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
            return xy.astype(np.float32)
        else:
            return boxes

    def __call__(self, scores, raw_boxes):
        batch_size = raw_boxes[0].shape[0]
        reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
        out_boxes_num = []
        out_boxes_list = []
        for batch_id in range(batch_size):
            # generate centers
            decode_boxes = []
            select_scores = []
            for stride, box_distribute, score in zip(self.strides, raw_boxes,
                                                     scores):
                box_distribute = box_distribute[batch_id]
                score = score[batch_id]
                # centers
                fm_h = self.input_shape[0] / stride
                fm_w = self.input_shape[1] / stride
                h_range = np.arange(fm_h)
                w_range = np.arange(fm_w)
                ww, hh = np.meshgrid(w_range, h_range)
                ct_row = (hh.flatten() + 0.5) * stride
                ct_col = (ww.flatten() + 0.5) * stride
                center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)

                # box distribution to distance
                reg_range = np.arange(reg_max + 1)
                box_distance = box_distribute.reshape((-1, reg_max + 1))
                box_distance = softmax(box_distance, axis=1)
                box_distance = box_distance * np.expand_dims(reg_range, axis=0)
                box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
                box_distance = box_distance * stride

                # top K candidates
                topk_idx = np.argsort(score.max(axis=1))[::-1]
                topk_idx = topk_idx[:self.nms_top_k]
                center = center[topk_idx]
                score = score[topk_idx]
                box_distance = box_distance[topk_idx]

                # decode box
                decode_box = center + [-1, -1, 1, 1] * box_distance

                select_scores.append(score)
                decode_boxes.append(decode_box)

            # nms
            bboxes = np.concatenate(decode_boxes, axis=0)
            confidences = np.concatenate(select_scores, axis=0)
            picked_box_probs = []
            picked_labels = []
            for class_index in range(0, confidences.shape[1]):
                probs = confidences[:, class_index]
                mask = probs > self.score_threshold
                probs = probs[mask]
                if probs.shape[0] == 0:
                    continue
                subset_boxes = bboxes[mask, :]
                box_probs = np.concatenate(
                    [subset_boxes, probs.reshape(-1, 1)], axis=1)
                box_probs = hard_nms(
                    box_probs,
                    iou_threshold=self.nms_threshold,
                    top_k=self.keep_top_k, )
                picked_box_probs.append(box_probs)
                picked_labels.extend([class_index] * box_probs.shape[0])

            if len(picked_box_probs) == 0:
                # keep the column count consistent with the non-empty case
                out_boxes_list.append(np.empty((0, 6)))
                out_boxes_num.append(0)

            else:
                picked_box_probs = np.concatenate(picked_box_probs)

                # resize output boxes
                picked_box_probs[:, :4] = self.warp_boxes(
                    picked_box_probs[:, :4], self.ori_shape[batch_id])
                im_scale = np.concatenate([
                    self.scale_factor[batch_id][::-1],
                    self.scale_factor[batch_id][::-1]
                ])
                picked_box_probs[:, :4] /= im_scale
                # output rows are [class, score, x_min, y_min, x_max, y_max]
                out_boxes_list.append(
                    np.concatenate(
                        [
                            np.expand_dims(
                                np.array(picked_labels),
                                axis=-1), np.expand_dims(
                                    picked_box_probs[:, 4], axis=-1),
                            picked_box_probs[:, :4]
                        ],
                        axis=1))
                out_boxes_num.append(len(picked_labels))

        out_boxes_list = np.concatenate(out_boxes_list, axis=0)
        out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
        return out_boxes_list, out_boxes_num
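

# Sketch of the distribution-to-distance step inside __call__: each box side
# is predicted as a softmax distribution over reg_max + 1 bins, and the decoded
# distance is the distribution's expected value scaled by the stride. The bin
# count and stride below are illustrative.
def _demo_distribution_decode():
    reg_max, stride = 7, 8
    logits = np.zeros((1, 4 * (reg_max + 1)), np.float32)  # one center, 4 sides
    dist = softmax(logits.reshape((-1, reg_max + 1)), axis=1)
    dist = np.sum(dist * np.arange(reg_max + 1), axis=1) * stride
    print(dist)  # uniform logits -> expected bin 3.5 -> 28.0 px per side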
@ -0,0 +1,549 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np
import imgaug.augmenters as iaa
from keypoint_preprocess import get_affine_transform
from PIL import Image


def decode_image(im_file, im_info):
    """read rgb image
    Args:
        im_file (str|np.ndarray): input can be image path or np.ndarray
        im_info (dict): info of image
    Returns:
        im (np.ndarray): processed image (np.ndarray)
        im_info (dict): info of processed image
    """
    if isinstance(im_file, str):
        with open(im_file, 'rb') as f:
            im_read = f.read()
        data = np.frombuffer(im_read, dtype='uint8')
        im = cv2.imdecode(data, 1)  # BGR mode, but we need RGB mode
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    else:
        im = im_file
    im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
    im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
    return im, im_info
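

# Sketch: decode_image also accepts an already-decoded ndarray, in which case
# it only fills in the default im_shape / scale_factor fields.
def _demo_decode_image():
    fake = np.zeros((240, 320, 3), np.uint8)  # synthetic RGB image
    im, info = decode_image(fake, {})
    print(info['im_shape'], info['scale_factor'])  # [240. 320.] [1. 1.]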


class Resize_Mult32(object):
    """resize image so that both sides become multiples of 32
    Args:
        limit_side_len (int): length limit applied to the image sides
        limit_type (str): one of 'max', 'min' or 'resize_long'
        interp (int): method of resize
    """

    def __init__(self, limit_side_len, limit_type, interp=cv2.INTER_LINEAR):
        self.limit_side_len = limit_side_len
        self.limit_type = limit_type
        self.interp = interp

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        im_channel = im.shape[2]
        im_scale_y, im_scale_x = self.generate_scale(im)
        im = cv2.resize(
            im,
            None,
            None,
            fx=im_scale_x,
            fy=im_scale_y,
            interpolation=self.interp)
        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
        im_info['scale_factor'] = np.array(
            [im_scale_y, im_scale_x]).astype('float32')
        return im, im_info

    def generate_scale(self, img):
        """
        Args:
            img (np.ndarray): image (np.ndarray)
        Returns:
            im_scale_x: the resize ratio of X
            im_scale_y: the resize ratio of Y
        """
        limit_side_len = self.limit_side_len
        h, w, c = img.shape

        # limit the max side
        if self.limit_type == 'max':
            if h > w:
                ratio = float(limit_side_len) / h
            else:
                ratio = float(limit_side_len) / w
        elif self.limit_type == 'min':
            if h < w:
                ratio = float(limit_side_len) / h
            else:
                ratio = float(limit_side_len) / w
        elif self.limit_type == 'resize_long':
            ratio = float(limit_side_len) / max(h, w)
        else:
            raise Exception('unsupported limit type: {}'.format(
                self.limit_type))
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        resize_h = max(int(round(resize_h / 32) * 32), 32)
        resize_w = max(int(round(resize_w / 32) * 32), 32)

        im_scale_y = resize_h / float(h)
        im_scale_x = resize_w / float(w)
        return im_scale_y, im_scale_x


class Resize(object):
    """resize image by target_size and max_size
    Args:
        target_size (int): the target size of image
        keep_ratio (bool): whether to keep the aspect ratio or not, default true
        interp (int): method of resize
    """

    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.keep_ratio = keep_ratio
        self.interp = interp

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        im_channel = im.shape[2]
        im_scale_y, im_scale_x = self.generate_scale(im)
        im = cv2.resize(
            im,
            None,
            None,
            fx=im_scale_x,
            fy=im_scale_y,
            interpolation=self.interp)
        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
        im_info['scale_factor'] = np.array(
            [im_scale_y, im_scale_x]).astype('float32')
        return im, im_info

    def generate_scale(self, im):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
        Returns:
            im_scale_x: the resize ratio of X
            im_scale_y: the resize ratio of Y
        """
        origin_shape = im.shape[:2]
        im_c = im.shape[2]
        if self.keep_ratio:
            im_size_min = np.min(origin_shape)
            im_size_max = np.max(origin_shape)
            target_size_min = np.min(self.target_size)
            target_size_max = np.max(self.target_size)
            im_scale = float(target_size_min) / float(im_size_min)
            if np.round(im_scale * im_size_max) > target_size_max:
                im_scale = float(target_size_max) / float(im_size_max)
            im_scale_x = im_scale
            im_scale_y = im_scale
        else:
            resize_h, resize_w = self.target_size
            im_scale_y = resize_h / float(origin_shape[0])
            im_scale_x = resize_w / float(origin_shape[1])
        return im_scale_y, im_scale_x
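

# Worked example of the keep_ratio branch above: a 400x600 image aimed at
# target_size [800, 1333] scales by the min-side ratio 800 / 400 = 2.0, and
# 2.0 * 600 = 1200 <= 1333, so no clamping by the max side is needed.
def _demo_resize_scale():
    op = Resize([800, 1333], keep_ratio=True)
    img = np.zeros((400, 600, 3), np.float32)
    print(op.generate_scale(img))  # (2.0, 2.0)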


class ShortSizeScale(object):
    """
    Scale images by short side.
    Args:
        short_size (float | int): the short side of an image will be scaled to this value.
        fixed_ratio (bool): whether to zoom with a fixed 4:3 ratio. default: True
        do_round (bool): whether to round when calculating the zoom ratio. default: False
        backend (str): graphics backend, 'pillow' or 'cv2'. default: 'pillow'
    """

    def __init__(self,
                 short_size,
                 fixed_ratio=True,
                 keep_ratio=None,
                 do_round=False,
                 backend='pillow'):
        self.short_size = short_size
        assert (fixed_ratio and not keep_ratio) or (
            not fixed_ratio
        ), "fixed_ratio and keep_ratio cannot be true at the same time"
        self.fixed_ratio = fixed_ratio
        self.keep_ratio = keep_ratio
        self.do_round = do_round

        assert backend in [
            'pillow', 'cv2'
        ], f"Scale's backend must be pillow or cv2, but got {backend}"

        self.backend = backend

    def __call__(self, img):
        """
        Performs resize operations.
        Args:
            img (PIL.Image): a PIL.Image.
        return:
            resized_img: a PIL.Image after scaling.
        """

        result_img = None

        if isinstance(img, np.ndarray):
            h, w, _ = img.shape
        elif isinstance(img, Image.Image):
            w, h = img.size
        else:
            raise NotImplementedError

        if w <= h:
            ow = self.short_size
            if self.fixed_ratio:  # default is True
                oh = int(self.short_size * 4.0 / 3.0)
            elif not self.keep_ratio:  # no
                oh = self.short_size
            else:
                scale_factor = self.short_size / w
                oh = int(h * float(scale_factor) +
                         0.5) if self.do_round else int(h * self.short_size / w)
                ow = int(w * float(scale_factor) +
                         0.5) if self.do_round else int(w * self.short_size / h)
        else:
            oh = self.short_size
            if self.fixed_ratio:
                ow = int(self.short_size * 4.0 / 3.0)
            elif not self.keep_ratio:  # no
                ow = self.short_size
            else:
                scale_factor = self.short_size / h
                oh = int(h * float(scale_factor) +
                         0.5) if self.do_round else int(h * self.short_size / w)
                ow = int(w * float(scale_factor) +
                         0.5) if self.do_round else int(w * self.short_size / h)

        if type(img) == np.ndarray:
            img = Image.fromarray(img, mode='RGB')

        if self.backend == 'pillow':
            result_img = img.resize((ow, oh), Image.BILINEAR)
        elif self.backend == 'cv2' and (self.keep_ratio is not None):
            result_img = cv2.resize(
                img, (ow, oh), interpolation=cv2.INTER_LINEAR)
        else:
            result_img = Image.fromarray(
                cv2.resize(
                    np.asarray(img), (ow, oh), interpolation=cv2.INTER_LINEAR))

        return result_img


class NormalizeImage(object):
    """normalize image
    Args:
        mean (list): im - mean
        std (list): im / std
        is_scale (bool): whether to scale im to [0, 1] by dividing by 255
        norm_type (str): type in ['mean_std', 'none']
    """

    def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
        self.mean = mean
        self.std = std
        self.is_scale = is_scale
        self.norm_type = norm_type

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        im = im.astype(np.float32, copy=False)
        if self.is_scale:
            scale = 1.0 / 255.0
            im *= scale

        if self.norm_type == 'mean_std':
            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
            std = np.array(self.std)[np.newaxis, np.newaxis, :]
            im -= mean
            im /= std
        return im, im_info


class Permute(object):
    """permute image from HWC layout to CHW layout"""

    def __init__(self, ):
        super(Permute, self).__init__()

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        im = im.transpose((2, 0, 1)).copy()
        return im, im_info


class PadStride(object):
    """padding image for models with FPN, instead of PadBatch(pad_to_stride) in the original config
    Args:
        stride (int): models with FPN need image shape % stride == 0
    """

    def __init__(self, stride=0):
        self.coarsest_stride = stride

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        coarsest_stride = self.coarsest_stride
        if coarsest_stride <= 0:
            return im, im_info
        im_c, im_h, im_w = im.shape
        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = im
        return padding_im, im_info
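

# Sketch: with stride 32, a 3x100x200 CHW tensor is zero-padded up to the next
# multiples of 32 (128 and 224) so the FPN feature maps divide evenly.
def _demo_pad_stride():
    op = PadStride(stride=32)
    im, _ = op(np.zeros((3, 100, 200), np.float32), {})
    print(im.shape)  # (3, 128, 224)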


class LetterBoxResize(object):
    def __init__(self, target_size):
        """
        Resize image to target size, convert normalized xywh to pixel xyxy
        format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
        Args:
            target_size (int|list): image target size.
        """
        super(LetterBoxResize, self).__init__()
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size

    def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
        # letterbox: resize a rectangular image to a padded rectangular one
        shape = img.shape[:2]  # [height, width]
        ratio_h = float(height) / shape[0]
        ratio_w = float(width) / shape[1]
        ratio = min(ratio_h, ratio_w)
        new_shape = (round(shape[1] * ratio),
                     round(shape[0] * ratio))  # [width, height]
        padw = (width - new_shape[0]) / 2
        padh = (height - new_shape[1]) / 2
        top, bottom = round(padh - 0.1), round(padh + 0.1)
        left, right = round(padw - 0.1), round(padw + 0.1)

        img = cv2.resize(
            img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT,
            value=color)  # padded rectangular
        return img, ratio, padw, padh

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        height, width = self.target_size
        h, w = im.shape[:2]
        im, ratio, padw, padh = self.letterbox(im, height=height, width=width)

        new_shape = [round(h * ratio), round(w * ratio)]
        im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
        im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
        return im, im_info
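

# Worked example: letterboxing a 300x400 image into 320x320 uses ratio
# min(320/300, 320/400) = 0.8, i.e. a 320x240 resize plus 40 px of gray
# padding on the top and bottom.
def _demo_letterbox():
    op = LetterBoxResize(320)
    im, info = op(np.zeros((300, 400, 3), np.uint8), {})
    print(im.shape, info['scale_factor'])  # (320, 320, 3) [0.8 0.8]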


class Pad(object):
    def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
        """
        Pad image to a specified size.
        Args:
            size (list[int]): image target size
            fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
        """
        super(Pad, self).__init__()
        if isinstance(size, int):
            size = [size, size]
        self.size = size
        self.fill_value = fill_value

    def __call__(self, im, im_info):
        im_h, im_w = im.shape[:2]
        h, w = self.size
        if h == im_h and w == im_w:
            im = im.astype(np.float32)
            return im, im_info

        canvas = np.ones((h, w, 3), dtype=np.float32)
        canvas *= np.array(self.fill_value, dtype=np.float32)
        canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
        im = canvas
        return im, im_info


class WarpAffine(object):
    """Warp-affine the image."""

    def __init__(self,
                 keep_res=False,
                 pad=31,
                 input_h=512,
                 input_w=512,
                 scale=0.4,
                 shift=0.1,
                 down_ratio=4):
        self.keep_res = keep_res
        self.pad = pad
        self.input_h = input_h
        self.input_w = input_w
        self.scale = scale
        self.shift = shift
        self.down_ratio = down_ratio

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): processed image (np.ndarray)
            im_info (dict): info of processed image
        """
        img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)

        h, w = img.shape[:2]

        if self.keep_res:
            # True in detection eval/infer
            input_h = (h | self.pad) + 1
            input_w = (w | self.pad) + 1
            s = np.array([input_w, input_h], dtype=np.float32)
            c = np.array([w // 2, h // 2], dtype=np.float32)

        else:
            # False in centertrack eval_mot/eval_mot
            s = max(h, w) * 1.0
            input_h, input_w = self.input_h, self.input_w
            c = np.array([w / 2., h / 2.], dtype=np.float32)

        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
        img = cv2.resize(img, (w, h))
        inp = cv2.warpAffine(
            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)

        if not self.keep_res:
            out_h = input_h // self.down_ratio
            out_w = input_w // self.down_ratio
            trans_output = get_affine_transform(c, s, 0, [out_w, out_h])

            im_info.update({
                'center': c,
                'scale': s,
                'out_height': out_h,
                'out_width': out_w,
                'inp_height': input_h,
                'inp_width': input_w,
                'trans_input': trans_input,
                'trans_output': trans_output,
            })
        return inp, im_info


class CULaneResize(object):
    def __init__(self, img_h, img_w, cut_height, prob=0.5):
        super(CULaneResize, self).__init__()
        self.img_h = img_h
        self.img_w = img_w
        self.cut_height = cut_height
        self.prob = prob

    def __call__(self, im, im_info):
        # cut off the top of the image
        im = im[self.cut_height:, :, :]
        # resize
        transform = iaa.Sometimes(self.prob,
                                  iaa.Resize({
                                      "height": self.img_h,
                                      "width": self.img_w
                                  }))
        im = transform(image=im.copy().astype(np.uint8))

        im = im.astype(np.float32) / 255.
        # TODO: check whether this transpose is needed for decode_image to
        # match what CULaneDataSet gets from cv2.imread
        im = im.transpose(2, 0, 1)

        return im, im_info


def preprocess(im, preprocess_ops):
    # process image by preprocess_ops
    im_info = {
        'scale_factor': np.array(
            [1., 1.], dtype=np.float32),
        'im_shape': None,
    }
    im, im_info = decode_image(im, im_info)
    for operator in preprocess_ops:
        im, im_info = operator(im, im_info)
    return im, im_info
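

# Sketch of composing the ops above the way a detector's preprocess step does:
# each operator consumes and returns (im, im_info). The sizes, mean and std
# below are illustrative values only.
def _demo_preprocess_pipeline():
    ops = [Resize([640, 640], keep_ratio=False),
           NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
           Permute()]
    im, im_info = preprocess(np.zeros((480, 640, 3), np.uint8), ops)
    print(im.shape, im_info['scale_factor'])  # (3, 640, 640) [1.3333334 1.]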
@ -0,0 +1,6 @@
export FLAGS_enable_pir_api=0

python3 pdf_detection.py \
    --model_dir=/mnt/research/PaddleOCR/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer \
    --image_file=/mnt/research/PaddleOCR/demo-75-images/12.jpg \
    --device=GPU
@ -0,0 +1,649 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division

import os
import cv2
import math
import numpy as np
import PIL
from PIL import Image, ImageDraw, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def imagedraw_textsize_c(draw, text):
    if int(PIL.__version__.split('.')[0]) < 10:
        tw, th = draw.textsize(text)
    else:
        left, top, right, bottom = draw.textbbox((0, 0), text)
        tw, th = right - left, bottom - top

    return tw, th


def visualize_box_mask(im, results, labels, threshold=0.5):
    """
    Args:
        im (str/np.ndarray): path of image/np.ndarray read by cv2
        results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of boxes,
            matrix element: [class, score, x_min, y_min, x_max, y_max]
            MaskRCNN's results include 'masks': np.ndarray:
            shape: [N, im_h, im_w]
        labels (list): labels: ['class1', ..., 'classn']
        threshold (float): Threshold of score.
    Returns:
        im (PIL.Image.Image): visualized image
    """
    if isinstance(im, str):
        im = Image.open(im).convert('RGB')
    elif isinstance(im, np.ndarray):
        im = Image.fromarray(im)
    if 'masks' in results and 'boxes' in results and len(results['boxes']) > 0:
        im = draw_mask(
            im, results['boxes'], results['masks'], labels, threshold=threshold)
    if 'boxes' in results and len(results['boxes']) > 0:
        im = draw_box(im, results['boxes'], labels, threshold=threshold)
    if 'segm' in results:
        im = draw_segm(
            im,
            results['segm'],
            results['label'],
            results['score'],
            labels,
            threshold=threshold)
    return im


def get_color_map_list(num_classes):
    """
    Args:
        num_classes (int): number of classes
    Returns:
        color_map (list): RGB color list
    """
    color_map = num_classes * [0, 0, 0]
    for i in range(0, num_classes):
        j = 0
        lab = i
        while lab:
            color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
            color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
            color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
            j += 1
            lab >>= 3
    color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
    return color_map
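

# Quick check: the bit-interleaving above yields a deterministic palette;
# class 0 maps to black and class 1 to dark red.
def _demo_color_map():
    print(get_color_map_list(3)[:2])  # [[0, 0, 0], [128, 0, 0]]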


def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
    """
    Args:
        im (PIL.Image.Image): PIL image
        np_boxes (np.ndarray): shape:[N,6], N: number of boxes,
            matrix element: [class, score, x_min, y_min, x_max, y_max]
        np_masks (np.ndarray): shape:[N, im_h, im_w]
        labels (list): labels: ['class1', ..., 'classn']
        threshold (float): threshold of mask
    Returns:
        im (PIL.Image.Image): visualized image
    """
    color_list = get_color_map_list(len(labels))
    w_ratio = 0.4
    alpha = 0.7
    im = np.array(im).astype('float32')
    clsid2color = {}
    expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
    np_boxes = np_boxes[expect_boxes, :]
    np_masks = np_masks[expect_boxes, :, :]
    im_h, im_w = im.shape[:2]
    np_masks = np_masks[:, :im_h, :im_w]
    for i in range(len(np_masks)):
        clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
        mask = np_masks[i]
        if clsid not in clsid2color:
            clsid2color[clsid] = color_list[clsid]
        color_mask = clsid2color[clsid]
        for c in range(3):
            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
        idx = np.nonzero(mask)
        color_mask = np.array(color_mask)
        im[idx[0], idx[1], :] *= 1.0 - alpha
        im[idx[0], idx[1], :] += alpha * color_mask
    return Image.fromarray(im.astype('uint8'))


def draw_box(im, np_boxes, labels, threshold=0.5):
    """
    Args:
        im (PIL.Image.Image): PIL image
        np_boxes (np.ndarray): shape:[N,6], N: number of boxes,
            matrix element: [class, score, x_min, y_min, x_max, y_max]
        labels (list): labels: ['class1', ..., 'classn']
        threshold (float): threshold of box
    Returns:
        im (PIL.Image.Image): visualized image
    """
    draw_thickness = min(im.size) // 320
    draw = ImageDraw.Draw(im)
    clsid2color = {}
    color_list = get_color_map_list(len(labels))
    expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
    np_boxes = np_boxes[expect_boxes, :]

    for dt in np_boxes:
        clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
        if clsid not in clsid2color:
            clsid2color[clsid] = color_list[clsid]
        color = tuple(clsid2color[clsid])

        if len(bbox) == 4:
            xmin, ymin, xmax, ymax = bbox
            print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'
                  'right_bottom:[{:.2f},{:.2f}]'.format(
                      int(clsid), score, xmin, ymin, xmax, ymax))
            # draw bbox
            draw.line(
                [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
                 (xmin, ymin)],
                width=draw_thickness,
                fill=color)
        elif len(bbox) == 8:
            x1, y1, x2, y2, x3, y3, x4, y4 = bbox
            draw.line(
                [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
                width=2,
                fill=color)
            xmin = min(x1, x2, x3, x4)
            ymin = min(y1, y2, y3, y4)

        # draw label
        text = "{} {:.4f}".format(labels[clsid], score)
        tw, th = imagedraw_textsize_c(draw, text)
        draw.rectangle(
            [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
        draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
    return im
|
||||
|
||||
|
||||
def draw_segm(im,
              np_segms,
              np_label,
              np_score,
              labels,
              threshold=0.5,
              alpha=0.7):
    """
    Draw segmentation on image
    """
    mask_color_id = 0
    w_ratio = .4
    color_list = get_color_map_list(len(labels))
    im = np.array(im).astype('float32')
    clsid2color = {}
    np_segms = np_segms.astype(np.uint8)
    for i in range(np_segms.shape[0]):
        mask, score, clsid = np_segms[i], np_score[i], np_label[i]
        if score < threshold:
            continue

        if clsid not in clsid2color:
            clsid2color[clsid] = color_list[clsid]
        color_mask = clsid2color[clsid]
        for c in range(3):
            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
        idx = np.nonzero(mask)
        color_mask = np.array(color_mask)
        idx0 = np.minimum(idx[0], im.shape[0] - 1)
        idx1 = np.minimum(idx[1], im.shape[1] - 1)
        im[idx0, idx1, :] *= 1.0 - alpha
        im[idx0, idx1, :] += alpha * color_mask
        sum_x = np.sum(mask, axis=0)
        x = np.where(sum_x > 0.5)[0]
        sum_y = np.sum(mask, axis=1)
        y = np.where(sum_y > 0.5)[0]
        x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
        cv2.rectangle(im, (x0, y0), (x1, y1),
                      tuple(color_mask.astype('int32').tolist()), 1)
        bbox_text = '%s %.2f' % (labels[clsid], score)
        t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
        cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3),
                      tuple(color_mask.astype('int32').tolist()), -1)
        cv2.putText(
            im,
            bbox_text, (x0, y0 - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.3, (0, 0, 0),
            1,
            lineType=cv2.LINE_AA)
    return Image.fromarray(im.astype('uint8'))
def get_color(idx):
    idx = idx * 3
    color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
    return color
def visualize_pose(imgfile,
                   results,
                   visual_thresh=0.6,
                   save_name='pose.jpg',
                   save_dir='output',
                   returnimg=False,
                   ids=None):
    try:
        import matplotlib.pyplot as plt
        import matplotlib
        plt.switch_backend('agg')
    except Exception as e:
        print('Matplotlib not found, please install matplotlib, '
              'for example: `pip install matplotlib`.')
        raise e
    skeletons, scores = results['keypoint']
    skeletons = np.array(skeletons)
    kpt_nums = 17
    if len(skeletons) > 0:
        kpt_nums = skeletons.shape[1]
    if kpt_nums == 17:  # plot coco keypoint
        EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
                 (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14),
                 (13, 15), (14, 16), (11, 12)]
    else:  # plot mpii keypoint
        EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7), (7, 8),
                 (8, 9), (10, 11), (11, 12), (13, 14), (14, 15), (8, 12),
                 (8, 13)]
    NUM_EDGES = len(EDGES)

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
              [170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85],
              [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255],
              [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
              [255, 0, 170], [255, 0, 85]]
    cmap = matplotlib.cm.get_cmap('hsv')
    plt.figure()

    img = cv2.imread(imgfile) if type(imgfile) == str else imgfile

    color_set = results['colors'] if 'colors' in results else None

    if 'bbox' in results and ids is None:
        bboxs = results['bbox']
        for j, rect in enumerate(bboxs):
            xmin, ymin, xmax, ymax = rect
            color = colors[0] if color_set is None else colors[color_set[j] %
                                                               len(colors)]
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)

    canvas = img.copy()
    for i in range(kpt_nums):
        for j in range(len(skeletons)):
            if skeletons[j][i, 2] < visual_thresh:
                continue
            if ids is None:
                color = colors[i] if color_set is None else colors[color_set[j]
                                                                   % len(colors)]
            else:
                color = get_color(ids[j])

            cv2.circle(
                canvas,
                tuple(skeletons[j][i, 0:2].astype('int32')),
                2,
                color,
                thickness=-1)

    to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0)
    fig = matplotlib.pyplot.gcf()

    stickwidth = 2

    for i in range(NUM_EDGES):
        for j in range(len(skeletons)):
            edge = EDGES[i]
            if skeletons[j][edge[0], 2] < visual_thresh or \
                    skeletons[j][edge[1], 2] < visual_thresh:
                continue

            cur_canvas = canvas.copy()
            X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]]
            Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]]
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)),
                                       (int(length / 2), stickwidth),
                                       int(angle), 0, 360, 1)
            if ids is None:
                color = colors[i] if color_set is None else colors[color_set[j]
                                                                   % len(colors)]
            else:
                color = get_color(ids[j])
            cv2.fillConvexPoly(cur_canvas, polygon, color)
            canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
    if returnimg:
        return canvas
    save_name = os.path.join(
        save_dir, os.path.splitext(os.path.basename(imgfile))[0] + '_vis.jpg')
    plt.imsave(save_name, canvas[:, :, ::-1])
    print("keypoint visualize image saved to: " + save_name)
    plt.close()
def visualize_attr(im, results, boxes=None, is_mtmct=False):
    if isinstance(im, str):
        im = Image.open(im)
        im = np.ascontiguousarray(np.copy(im))
        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    else:
        im = np.ascontiguousarray(np.copy(im))

    im_h, im_w = im.shape[:2]
    text_scale = max(0.5, im.shape[0] / 3000.)
    text_thickness = 1

    line_inter = im.shape[0] / 40.
    for i, res in enumerate(results):
        if boxes is None:
            text_w = 3
            text_h = 1
        elif is_mtmct:
            box = boxes[i]  # multi camera, bbox shape is x, y, w, h
            text_w = int(box[0]) + 3
            text_h = int(box[1])
        else:
            box = boxes[i]  # single camera, bbox shape is 0, 0, x, y, w, h
            text_w = int(box[2]) + 3
            text_h = int(box[3])
        for text in res:
            text_h += int(line_inter)
            text_loc = (text_w, text_h)
            cv2.putText(
                im,
                text,
                text_loc,
                cv2.FONT_ITALIC,
                text_scale, (0, 255, 255),
                thickness=text_thickness)
    return im
def visualize_action(im,
                     mot_boxes,
                     action_visual_collector=None,
                     action_text="",
                     video_action_score=None,
                     video_action_text=""):
    im = cv2.imread(im) if isinstance(im, str) else im
    im_h, im_w = im.shape[:2]

    text_scale = max(1, im.shape[1] / 400.)
    text_thickness = 2

    if action_visual_collector:
        id_action_dict = {}
        for collector, action_type in zip(action_visual_collector, action_text):
            id_detected = collector.get_visualize_ids()
            for pid in id_detected:
                id_action_dict[pid] = id_action_dict.get(pid, [])
                id_action_dict[pid].append(action_type)
        for mot_box in mot_boxes:
            # mot_box is a format with [mot_id, class, score, xmin, ymin, w, h]
            if mot_box[0] in id_action_dict:
                text_position = (int(mot_box[3] + mot_box[5] * 0.75),
                                 int(mot_box[4] - 10))
                display_text = ', '.join(id_action_dict[mot_box[0]])
                cv2.putText(im, display_text, text_position,
                            cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255), 2)

    if video_action_score:
        cv2.putText(
            im,
            video_action_text + ': %.2f' % video_action_score,
            (int(im_w / 2), int(15 * text_scale) + 5),
            cv2.FONT_ITALIC,
            text_scale, (0, 0, 255),
            thickness=text_thickness)

    return im
def visualize_vehicleplate(im, results, boxes=None):
    if isinstance(im, str):
        im = Image.open(im)
        im = np.ascontiguousarray(np.copy(im))
        im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    else:
        im = np.ascontiguousarray(np.copy(im))

    im_h, im_w = im.shape[:2]
    text_scale = max(1.0, im.shape[0] / 400.)
    text_thickness = 2

    line_inter = im.shape[0] / 40.
    for i, res in enumerate(results):
        # bind the plate text up front so it is defined on both branches
        text = res
        if boxes is None:
            text_w = 3
            text_h = 1
        else:
            box = boxes[i]
            if text == "":
                continue
            text_w = int(box[2])
            text_h = int(box[5] + box[3])
        text_loc = (text_w, text_h)
        cv2.putText(
            im,
            "LP: " + text,
            text_loc,
            cv2.FONT_ITALIC,
            text_scale, (0, 255, 255),
            thickness=text_thickness)
    return im
def draw_press_box_lanes(im, np_boxes, labels, threshold=0.5):
    """
    Args:
        im (PIL.Image.Image): PIL image
        np_boxes (np.ndarray): shape:[N,6], N: number of box,
            matrix element:[class, score, x_min, y_min, x_max, y_max]
        labels (list): labels:['class1', ..., 'classn']
        threshold (float): threshold of box
    Returns:
        im (PIL.Image.Image): visualized image
    """

    if isinstance(im, str):
        im = Image.open(im).convert('RGB')
    elif isinstance(im, np.ndarray):
        im = Image.fromarray(im)

    draw_thickness = min(im.size) // 320
    draw = ImageDraw.Draw(im)
    clsid2color = {}
    color_list = get_color_map_list(len(labels))

    if np_boxes.shape[1] == 7:
        np_boxes = np_boxes[:, 1:]

    expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
    np_boxes = np_boxes[expect_boxes, :]

    for dt in np_boxes:
        clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
        if clsid not in clsid2color:
            clsid2color[clsid] = color_list[clsid]
        color = tuple(clsid2color[clsid])

        if len(bbox) == 4:
            xmin, ymin, xmax, ymax = bbox
            # draw bbox
            draw.line(
                [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
                 (xmin, ymin)],
                width=draw_thickness,
                fill=(0, 0, 255))
        elif len(bbox) == 8:
            x1, y1, x2, y2, x3, y3, x4, y4 = bbox
            draw.line(
                [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
                width=2,
                fill=color)
            xmin = min(x1, x2, x3, x4)
            ymin = min(y1, y2, y3, y4)

        # draw label
        text = "{}".format(labels[clsid])
        tw, th = imagedraw_textsize_c(draw, text)
        draw.rectangle(
            [(xmin + 1, ymax - th), (xmin + tw + 1, ymax)], fill=color)
        draw.text((xmin + 1, ymax - th), text, fill=(0, 0, 255))
    return im
def visualize_vehiclepress(im, results, threshold=0.5):
    results = np.array(results)
    labels = ['violation']
    im = draw_press_box_lanes(im, results, labels, threshold=threshold)
    return im
def visualize_lane(im, lanes):
    if isinstance(im, str):
        im = Image.open(im).convert('RGB')
    elif isinstance(im, np.ndarray):
        im = Image.fromarray(im)

    draw_thickness = min(im.size) // 320
    draw = ImageDraw.Draw(im)

    if len(lanes) > 0:
        for lane in lanes:
            draw.line(
                [(lane[0], lane[1]), (lane[2], lane[3])],
                width=draw_thickness,
                fill=(0, 0, 255))

    return im
def visualize_vehicle_retrograde(im, mot_res, vehicle_retrograde_res):
    if isinstance(im, str):
        im = Image.open(im).convert('RGB')
    elif isinstance(im, np.ndarray):
        im = Image.fromarray(im)

    draw_thickness = min(im.size) // 320
    draw = ImageDraw.Draw(im)

    lane = vehicle_retrograde_res['fence_line']
    if lane is not None:
        draw.line(
            [(lane[0], lane[1]), (lane[2], lane[3])],
            width=draw_thickness,
            fill=(0, 0, 0))

    mot_id = vehicle_retrograde_res['output']
    if mot_id is None or len(mot_id) == 0:
        return im

    if mot_res is None:
        return im
    np_boxes = mot_res['boxes']

    if np_boxes is not None:
        for dt in np_boxes:
            if dt[0] not in mot_id:
                continue
            bbox = dt[3:]
            if len(bbox) == 4:
                xmin, ymin, xmax, ymax = bbox
                # draw bbox
                draw.line(
                    [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
                     (xmin, ymin)],
                    width=draw_thickness,
                    fill=(0, 255, 0))

                # draw label
                text = "retrograde"
                tw, th = imagedraw_textsize_c(draw, text)
                draw.rectangle(
                    [(xmax + 1, ymin - th), (xmax + tw + 1, ymin)],
                    fill=(0, 255, 0))
                draw.text((xmax + 1, ymin - th), text, fill=(0, 255, 0))

    return im
COLORS = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
    (0, 255, 255), (128, 255, 0), (255, 128, 0), (128, 0, 255), (255, 0, 128),
    (0, 128, 255), (0, 255, 128), (128, 255, 255), (255, 128, 255),
    (255, 255, 128), (60, 180, 0), (180, 60, 0), (0, 60, 180), (0, 180, 60),
    (60, 0, 180), (180, 0, 60), (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (255, 0, 255), (0, 255, 255), (128, 255, 0), (255, 128, 0),
    (128, 0, 255),
]
def imshow_lanes(img, lanes, show=False, out_file=None, width=4):
    lanes_xys = []
    for _, lane in enumerate(lanes):
        xys = []
        for x, y in lane:
            if x <= 0 or y <= 0:
                continue
            x, y = int(x), int(y)
            xys.append((x, y))
        lanes_xys.append(xys)
    lanes_xys.sort(key=lambda xys: xys[0][0] if len(xys) > 0 else 0)

    for idx, xys in enumerate(lanes_xys):
        for i in range(1, len(xys)):
            cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width)

    if show:
        cv2.imshow('view', img)
        cv2.waitKey(0)

    if out_file:
        if not os.path.exists(os.path.dirname(out_file)):
            os.makedirs(os.path.dirname(out_file))
        cv2.imwrite(out_file, img)
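
# Hedged usage sketch (illustrative only; not part of the original file;
# the /tmp output path is an assumption). Draws two synthetic lanes on a
# blank frame and writes the visualization to disk via imshow_lanes.
def _demo_imshow_lanes():
    img = np.zeros((240, 320, 3), dtype=np.uint8)
    lanes = [[(50, 230), (60, 150), (70, 60)],
             [(260, 230), (250, 150), (240, 60)]]
    imshow_lanes(img, lanes, show=False, out_file='/tmp/lanes_vis/out.jpg')
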
@ -0,0 +1,62 @@
{
    "bucket_info": {
        "bucket-name-1": [
            "ak",
            "sk",
            "endpoint"
        ],
        "bucket-name-2": [
            "ak",
            "sk",
            "endpoint"
        ]
    },
    "models-dir": "/root/.cache/modelscope/hub/models/opendatalab/PDF-Extract-Kit-1___0/models",
    "layoutreader-model-dir": "/root/.cache/modelscope/hub/models/ppaanngggg/layoutreader",
    "device-mode": "cuda",
    "layout-config": {
        "model": "doclayout_yolo"
    },
    "formula-config": {
        "mfd_model": "yolo_v8_mfd",
        "mfr_model": "unimernet_small",
        "enable": false
    },
    "table-config": {
        "model": "rapid_table",
        "sub_model": "slanet_plus",
        "enable": false,
        "max_time": 400
    },
    "latex-delimiter-config": {
        "display": {
            "left": "$$",
            "right": "$$"
        },
        "inline": {
            "left": "$",
            "right": "$"
        }
    },
    "llm-aided-config": {
        "formula_aided": {
            "api_key": "your_api_key",
            "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "model": "qwen2.5-7b-instruct",
            "enable": false
        },
        "text_aided": {
            "api_key": "your_api_key",
            "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "model": "qwen2.5-7b-instruct",
            "enable": false
        },
        "title_aided": {
            "api_key": "your_api_key",
            "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
            "model": "qwen2.5-32b-instruct",
            "enable": false
        }
    },
    "config_version": "1.2.1"
}
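
# Hedged Python sketch (illustrative only; not part of the commit; assumes
# the JSON above is installed as ~/magic-pdf.json, the user-level location
# the repo's download script targets). Shows how to read back the toggles.
import json
import os

cfg_path = os.path.join(os.path.expanduser('~'), 'magic-pdf.json')
with open(cfg_path, encoding='utf-8') as f:
    cfg = json.load(f)
print(cfg['device-mode'])               # "cuda"
print(cfg['formula-config']['enable'])  # False: formula recognition is off
print(cfg['table-config']['enable'])    # False: table recognition is off
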
@ -0,0 +1,48 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: PicoDet
min_subgraph_size: 3
Preprocess:
- interp: 2
  keep_ratio: false
  target_size:
  - 800
  - 608
  type: Resize
- is_scale: true
  mean:
  - 0.485
  - 0.456
  - 0.406
  std:
  - 0.229
  - 0.224
  - 0.225
  type: NormalizeImage
- type: Permute
- stride: 32
  type: PadStride
label_list:
- Text
- Title
- Figure
- Figure caption
- Table
- Table caption
- Header
- Footer
- Reference
- Equation
NMS:
  keep_top_k: 100
  name: MultiClassNMS
  nms_threshold: 0.5
  nms_top_k: 1000
  score_threshold: 0.3
fpn_stride:
- 8
- 16
- 32
- 64
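
# Hedged Python sketch (illustrative only; not part of the commit). Roughly
# what the Preprocess chain above performs on one image before PicoDet
# inference: Resize (keep_ratio false, interp 2 = cv2.INTER_CUBIC),
# NormalizeImage with ImageNet mean/std, Permute HWC->CHW, then PadStride
# so both spatial sides are multiples of 32.
import cv2
import numpy as np

def preprocess(img_bgr):
    img = cv2.resize(img_bgr, (608, 800), interpolation=2)  # target_size: [800, 608] as (h, w)
    img = img[:, :, ::-1].astype('float32') / 255.0         # BGR -> RGB, is_scale: true
    mean = np.array([0.485, 0.456, 0.406], dtype='float32')
    std = np.array([0.229, 0.224, 0.225], dtype='float32')
    img = (img - mean) / std                                 # NormalizeImage
    img = img.transpose(2, 0, 1)                             # Permute: HWC -> CHW
    c, h, w = img.shape
    ph, pw = int(np.ceil(h / 32) * 32), int(np.ceil(w / 32) * 32)
    padded = np.zeros((c, ph, pw), dtype='float32')          # PadStride: 32
    padded[:, :h, :w] = img
    return padded
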
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,43 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: PicoDet
min_subgraph_size: 3
Preprocess:
- interp: 2
  keep_ratio: false
  target_size:
  - 800
  - 608
  type: Resize
- is_scale: true
  mean:
  - 0.485
  - 0.456
  - 0.406
  std:
  - 0.229
  - 0.224
  - 0.225
  type: NormalizeImage
- type: Permute
- stride: 32
  type: PadStride
label_list:
- text
- title
- list
- table
- figure
NMS:
  keep_top_k: 100
  name: MultiClassNMS
  nms_threshold: 0.5
  nms_top_k: 1000
  score_threshold: 0.3
fpn_stride:
- 8
- 16
- 32
- 64
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,39 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: PicoDet
min_subgraph_size: 3
Preprocess:
- interp: 2
  keep_ratio: false
  target_size:
  - 800
  - 608
  type: Resize
- is_scale: true
  mean:
  - 0.485
  - 0.456
  - 0.406
  std:
  - 0.229
  - 0.224
  - 0.225
  type: NormalizeImage
- type: Permute
- stride: 32
  type: PadStride
label_list:
- table
NMS:
  keep_top_k: 100
  name: MultiClassNMS
  nms_threshold: 0.5
  nms_top_k: 1000
  score_threshold: 0.3
fpn_stride:
- 8
- 16
- 32
- 64
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,43 @@
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
arch: PicoDet
min_subgraph_size: 3
Preprocess:
- interp: 2
  keep_ratio: false
  target_size:
  - 800
  - 608
  type: Resize
- is_scale: true
  mean:
  - 0.485
  - 0.456
  - 0.406
  std:
  - 0.229
  - 0.224
  - 0.225
  type: NormalizeImage
- type: Permute
- stride: 32
  type: PadStride
label_list:
- text
- title
- list
- table
- figure
NMS:
  keep_top_k: 100
  name: MultiClassNMS
  nms_threshold: 0.5
  nms_top_k: 1000
  score_threshold: 0.3
fpn_stride:
- 8
- 16
- 32
- 64
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,94 @@
from dotenv import load_dotenv
import os

env = os.environ.get('env', 'dev')
load_dotenv(dotenv_path='.env.dev' if env == 'dev' else '.env', override=True)

import time
import traceback
import cv2
from helper.image_helper import pdf2image, image_orient_cls, page_detection_visual
from helper.page_detection.main import layout_analysis
from helper.content_recognition.main import rec
from helper.db_helper import insert_pdf2md_table
import tempfile
from loguru import logger
import datetime
import shutil


def _pdf2markdown_pipeline(pdf_path, tmp_dir):
    start_time = time.time()
    # 1. pdf -> images
    t1 = time.time()
    pdf2image(pdf_path, tmp_dir)
    t2 = time.time()

    # 2. image orientation classification
    t3 = time.time()
    orient_cls_results = image_orient_cls(tmp_dir)
    t4 = time.time()
    for r in orient_cls_results:
        clsid = r[0]['class_ids'][0]
        filename = r[0]['filename']
        # rotate pages classified as sideways back to upright
        if clsid == 1 or clsid == 3:
            img = cv2.imread(filename)
            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
            cv2.imwrite(filename, img)

    filepaths = os.listdir(tmp_dir)
    filepaths.sort(key=lambda x: int(x.split('/')[-1].split('.')[0]))
    filepaths = [f'{tmp_dir}/{_}' for _ in filepaths]

    # filepaths = filepaths[:75]

    # 3. layout analysis
    t5 = time.time()
    layout_detection_results = layout_analysis(filepaths)
    t6 = time.time()

    # 3.1 visualization
    if int(os.environ['VISUAL']):
        visual_dir = './visual_images'
        for f in os.listdir(visual_dir):
            if f.endswith('.jpg'):
                os.remove(f'{visual_dir}/{f}')
        for i in range(len(layout_detection_results)):
            vis_img = page_detection_visual(layout_detection_results[i])
            cv2.imwrite(f'{visual_dir}/{i + 1}.jpg', vis_img)

    # 4. content recognition
    t7 = time.time()
    layout_recognition_results = rec(layout_detection_results, tmp_dir)
    t8 = time.time()

    end_time = time.time()
    logger.info(f'{pdf_path} analysis completed in {round(end_time - start_time, 3)} seconds, '
                f'including {round(t2 - t1, 3)} seconds for pdf-to-image conversion, '
                f'{round(t4 - t3, 3)} seconds for image orientation classification, '
                f'{round(t6 - t5, 3)} seconds for page detection, '
                f'and {round(t8 - t7, 3)} seconds for layout recognition; page count: {len(filepaths)}')

    return layout_recognition_results


def pdf2markdown_pipeline(pdf_path: str):
    pdf_name = pdf_path.split('/')[-1]
    start_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
    process_status = 0
    tmp_dir = tempfile.mkdtemp()
    try:
        results = _pdf2markdown_pipeline(pdf_path, tmp_dir)
    except Exception:
        logger.error(f'analysis pdf error! \n{traceback.format_exc()}')
        process_status = 3
        end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, None)
        pdf_id = None
    else:
        process_status = 2
        end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        pdf_id = insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, results)
    finally:
        shutil.rmtree(tmp_dir)
    return process_status, pdf_id


if __name__ == '__main__':
    pdf2markdown_pipeline('/mnt/pdf2markdown/demo.pdf')
@ -0,0 +1,189 @@
albucore==0.0.24
albumentations==2.0.6
annotated-types==0.7.0
anthropic==0.46.0
antlr4-python3-runtime==4.9.3
anyio==4.9.0
astor==0.8.1
babel==2.17.0
bce-python-sdk==0.9.29
beautifulsoup4==4.13.4
blinker==1.9.0
boto3==1.38.8
botocore==1.38.8
Brotli==1.1.0
cachetools==5.5.2
certifi==2025.4.26
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
cobble==0.1.4
coloredlogs==15.0.1
colorlog==6.9.0
contourpy==1.3.2
cryptography==44.0.3
cssselect2==0.8.0
cycler==0.12.1
Cython==3.0.12
decorator==5.2.1
dill==0.4.0
distlib==0.3.9
distro==1.9.0
doclayout_yolo==0.0.2b1
easydict==1.13
EbookLib==0.18
et_xmlfile==2.0.0
faiss-cpu==1.8.0.post1
fast-langdetect==0.2.5
fasttext-predict==0.9.2.4
filelock==3.18.0
filetype==1.2.0
fire==0.7.0
Flask==3.1.0
flask-babel==4.0.0
flatbuffers==25.2.10
fonttools==4.57.0
fsspec==2025.3.2
ftfy==6.3.1
future==1.0.0
gast==0.3.3
google-auth==2.39.0
google-genai==1.13.0
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
huggingface-hub==0.30.2
humanfriendly==10.0
identify==2.6.10
idna==3.10
imageio==2.37.0
imgaug==0.4.0
itsdangerous==2.2.0
Jinja2==3.1.6
jiter==0.9.0
jmespath==1.0.1
joblib==1.4.2
kiwisolver==1.4.8
lazy_loader==0.4
lmdb==1.6.2
loguru==0.7.3
lxml==5.4.0
magic-pdf==1.3.10
mammoth==1.9.0
markdown2==2.5.3
markdownify==0.13.1
marker-pdf==1.6.2
MarkupSafe==3.0.2
matplotlib==3.10.1
modelscope==1.25.0
mpmath==1.3.0
networkx==3.4.2
nodeenv==1.9.1
numpy==1.24.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0
onnxruntime==1.21.1
openai==1.77.0
opencv-contrib-python==4.11.0.86
opencv-python==4.6.0.66
opencv-python-headless==4.11.0.86
openpyxl==3.1.5
opt-einsum==3.3.0
packaging==25.0
paddleclas==2.5.2
paddleocr==2.10.0
paddlepaddle-gpu==2.6.2
pandas==2.2.3
pdf2image==1.17.0
pdfminer.six==20250324
pdftext==0.6.2
pillow==10.4.0
platformdirs==4.3.7
pre_commit==4.2.0
prettytable==3.16.0
protobuf==6.30.2
psutil==7.0.0
psycopg2==2.9.10
py-cpuinfo==9.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.2
pyclipper==1.3.0.post6
pycparser==2.22
pycryptodome==3.22.0
pydantic==2.10.6
pydantic-settings==2.9.1
pydantic_core==2.27.2
pydyf==0.11.0
PyMuPDF==1.24.14
pyparsing==3.2.3
pypdfium2==4.30.0
pyphen==0.17.2
python-dateutil==2.9.0.post0
python-docx==1.1.2
python-dotenv==1.1.0
python-pptx==1.0.2
pytz==2025.2
PyYAML==6.0.2
rapid-table==1.0.5
RapidFuzz==3.13.0
rapidocr==2.0.7
rarfile==4.2
regex==2024.11.6
requests==2.32.3
robust-downloader==0.0.2
rsa==4.9.1
s3transfer==0.12.0
safetensors==0.5.3
scikit-image==0.25.2
scikit-learn==1.6.1
scipy==1.15.2
seaborn==0.13.2
shapely==2.1.0
simsimd==6.2.1
six==1.17.0
sniffio==1.3.1
soupsieve==2.7
stringzilla==3.12.5
surya-ocr==0.13.1
sympy==1.13.1
termcolor==3.1.0
thop==0.1.1.post2209072238
threadpoolctl==3.6.0
tifffile==2025.3.30
tinycss2==1.4.0
tinyhtml5==2.0.0
tokenizers==0.21.1
torch==2.6.0
torchvision==0.21.0
tqdm==4.67.1
transformers==4.51.3
triton==3.2.0
typing-inspection==0.4.0
typing_extensions==4.13.2
tzdata==2025.2
ujson==5.10.0
ultralytics==8.3.127
ultralytics-thop==2.0.14
urllib3==2.4.0
virtualenv==20.31.1
visualdl==2.5.3
wcwidth==0.2.13
weasyprint==63.1
webencodings==0.5.1
websockets==15.0.1
Werkzeug==3.1.3
XlsxWriter==3.2.3
zopfli==0.2.3.post1
@ -0,0 +1,16 @@
from flask import Flask, request
import requests
from pipeline import pdf2markdown_pipeline


app = Flask(__name__)


# The handler reads a JSON body, so the route must accept POST (the original
# declaration relied on Flask's default GET-only methods, which would reject
# the request).
@app.route('/pdf-qa-server/pdf-to-md', methods=['POST'])
def pdf2markdown():
    data = request.json
    pdf_paths = data['pathList']
    callback_url = data['webhookUrl']
    for pdf_path in pdf_paths:
        process_status, pdf_id = pdf2markdown_pipeline(pdf_path)
        requests.post(callback_url, json={'pdfId': pdf_id, 'processStatus': process_status})
    # Flask views must return a response; acknowledge how many PDFs were queued
    return {'received': len(pdf_paths)}
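
# Hedged client-side sketch (illustrative only; host, port, and paths are
# assumptions, since the commit does not show how the app is served).
# Submits a batch of PDFs plus a webhook the server calls once per file.
import requests

requests.post(
    'http://127.0.0.1:5000/pdf-qa-server/pdf-to-md',
    json={
        'pathList': ['/mnt/pdf2markdown/demo.pdf'],
        'webhookUrl': 'http://127.0.0.1:8080/pdf-callback',
    })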