# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os
import platform
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    get_available_providers,
    get_device,
)

from rapid_table.utils import Logger


class EP(Enum):
    CPU_EP = "CPUExecutionProvider"
    CUDA_EP = "CUDAExecutionProvider"
    DIRECTML_EP = "DmlExecutionProvider"


class OrtInferSession:
    def __init__(self, config: Dict[str, Any]):
        self.logger = Logger(logger_name=__name__).get_log()

        model_path = config.get("model_path", None)
        self._verify_model(model_path)

        self.cfg_use_cuda = config.get("use_cuda", None)
        self.cfg_use_dml = config.get("use_dml", None)

        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        self.session = InferenceSession(
            model_path,
            sess_options=sess_opt,
            providers=EP_list,
        )
        self._verify_providers()

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        # os.cpu_count() may return None; fall back to 1 so the range checks
        # below never compare an int against None.
        cpu_nums = os.cpu_count() or 1
        intra_op_num_threads = config.get("intra_op_num_threads", -1)
        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = config.get("inter_op_num_threads", -1)
        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
            sess_opt.inter_op_num_threads = inter_op_num_threads

        return sess_opt

    def get_metadata(self, key: str = "character") -> list:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        content_list = meta_dict[key].splitlines()
        return content_list

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

        cuda_provider_opts = {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        }
        self.use_cuda = self._check_cuda()
        if self.use_cuda:
            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))

        self.use_directml = self._check_dml()
        if self.use_directml:
            self.logger.info(
                "Windows 10 or above detected, trying to use DirectML as the primary provider"
            )
            directml_options = (
                cuda_provider_opts if self.use_cuda else cpu_provider_opts
            )
            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
        return EP_list

    def _check_cuda(self) -> bool:
        if not self.cfg_use_cuda:
            return False

        cur_device = get_device()
        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.CUDA_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("It is recommended to use rapidocr_paddle for inference on GPU.")
        self.logger.info(
            "(For reference only) If you want to use GPU acceleration, you must do the following:"
        )
        self.logger.info(
            "First, uninstall all onnxruntime packages in the current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
        )
        self.logger.info(
            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
        )
        self.logger.info(
            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
        )
        self.logger.info(
            "Third, ensure %s is in the available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
            EP.CUDA_EP.value,
        )
        return False

    def _check_dml(self) -> bool:
        if not self.cfg_use_dml:
            return False

        cur_os = platform.system()
        if cur_os != "Windows":
            self.logger.warning(
                "DirectML is only supported on Windows. The current OS is %s. Use %s inference by default.",
                cur_os,
                self.had_providers[0],
            )
            return False

        cur_window_version = int(platform.release().split(".")[0])
        if cur_window_version < 10:
            self.logger.warning(
                "DirectML is only supported on Windows 10 and above. The current Windows version is %s. Use %s inference by default.",
                cur_window_version,
                self.had_providers[0],
            )
            return False

        if EP.DIRECTML_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.DIRECTML_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("If you want to use DirectML acceleration, you must do the following:")
        self.logger.info(
            "First, uninstall all onnxruntime packages in the current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
        )
        self.logger.info(
            "Third, ensure %s is in the available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
            EP.DIRECTML_EP.value,
        )
        return False

    def _verify_providers(self):
        session_providers = self.session.get_providers()
        first_provider = session_providers[0]

        if self.use_cuda and first_provider != EP.CUDA_EP.value:
            self.logger.warning(
                "%s is not available in the current environment; inference falls back to %s.",
                EP.CUDA_EP.value,
                first_provider,
            )

        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
            self.logger.warning(
                "%s is not available in the current environment; inference falls back to %s.",
                EP.DIRECTML_EP.value,
                first_provider,
            )

    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
        # session.run returns a list of output arrays, so the annotation
        # reflects that rather than a single ndarray.
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        if key in meta_dict.keys():
            return True
        return False

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")


class ONNXRuntimeError(Exception):
    pass


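# --- Hedged usage sketch (illustrative only, not part of the original module) ---
# A minimal example of wiring OrtInferSession together. "model.onnx" is a
# hypothetical path and the input shape is invented for demonstration; the
# config keys mirror those read in OrtInferSession.__init__.
def _demo_ort_infer_session():
    config = {
        "model_path": "model.onnx",  # hypothetical ONNX model file
        "use_cuda": False,
        "use_dml": False,
        "intra_op_num_threads": -1,
        "inter_op_num_threads": -1,
    }
    session = OrtInferSession(config)
    dummy_input = np.zeros((1, 3, 488, 488), dtype=np.float32)  # assumed input shape
    try:
        return session([dummy_input])  # list of output arrays
    except ONNXRuntimeError as exc:
        print(exc)
        return None

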
class TableLabelDecode:
    def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
        if merge_no_span_structure:
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")

        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character
        self.td_token = ["<td>", "<td", "<td></td>"]

    def __call__(self, preds, batch=None):
        structure_probs = preds["structure_probs"]
        bbox_preds = preds["loc_preds"]
        shape_list = batch[-1]
        result = self.decode(structure_probs, bbox_preds, shape_list)
        if len(batch) == 1:  # only contains shape
            return result

        label_decode_result = self.decode_label(batch)
        return result, label_decode_result

    def decode(self, structure_probs, bbox_preds, shape_list):
        """convert text-index into text-label."""
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        batch_size = len(structure_idx)
        for batch_idx in range(batch_size):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                if idx > 0 and char_idx == end_idx:
                    break

                if char_idx in ignored_tokens:
                    continue

                text = self.character[char_idx]
                if text in self.td_token:
                    bbox = bbox_preds[batch_idx, idx]
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            structure_batch_list.append([structure_list, np.mean(score_list)])
            bbox_batch_list.append(np.array(bbox_list))
        result = {
            "bbox_batch_list": bbox_batch_list,
            "structure_batch_list": structure_batch_list,
        }
        return result

    def decode_label(self, batch):
        """convert text-index (ground-truth labels) into text-label."""
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = self.get_ignored_tokens()
        end_idx = self.dict[self.end_str]

        structure_batch_list = []
        bbox_batch_list = []
        batch_size = len(structure_idx)
        for batch_idx in range(batch_size):
            structure_list = []
            bbox_list = []
            for idx in range(len(structure_idx[batch_idx])):
                char_idx = int(structure_idx[batch_idx][idx])
                if idx > 0 and char_idx == end_idx:
                    break

                if char_idx in ignored_tokens:
                    continue

                structure_list.append(self.character[char_idx])

                bbox = gt_bbox_list[batch_idx][idx]
                if bbox.sum() != 0:
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)

            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)
        result = {
            "bbox_batch_list": bbox_batch_list,
            "structure_batch_list": structure_batch_list,
        }
        return result

    def _bbox_decode(self, bbox, shape):
        # scale normalized coordinates back to the resized image size
        h, w = shape[:2]
        bbox[0::2] *= w
        bbox[1::2] *= h
        return bbox

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])

        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])

        raise TypeError(f"unsupported type {beg_or_end} in get_beg_end_flag_idx")

    def add_special_char(self, dict_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        dict_character = [self.beg_str] + dict_character + [self.end_str]
        return dict_character


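# --- Hedged usage sketch (illustrative only, not part of the original module) ---
# Decoding dummy model outputs with TableLabelDecode. The three-token
# vocabulary and the random tensors below are invented purely to show the
# expected shapes of `preds` and `batch`.
def _demo_table_label_decode():
    decoder = TableLabelDecode(["<td></td>", "</td>", "<tr>"])
    preds = {
        "structure_probs": np.random.rand(1, 16, len(decoder.character)),
        "loc_preds": np.random.rand(1, 16, 8),
    }
    shape_list = np.array([[488.0, 488.0, 1.0, 1.0, 488.0, 488.0]])
    # a batch holding only the shape list returns just the decode result
    return decoder(preds, batch=[shape_list])

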
class TablePreprocess:
    def __init__(self):
        self.table_max_len = 488
        self.build_pre_process_list()
        self.ops = self.create_operators()

    def __call__(self, data):
        """transform"""
        if self.ops is None:
            self.ops = []

        for op in self.ops:
            data = op(data)
            if data is None:
                return None
        return data

    def create_operators(self):
        """
        Create operators based on ``self.pre_process_list``: a list of
        single-key dicts mapping an operator class name to its parameters.
        """
        assert isinstance(
            self.pre_process_list, list
        ), "operator config should be a list"
        ops = []
        for operator in self.pre_process_list:
            assert (
                isinstance(operator, dict) and len(operator) == 1
            ), "yaml format error"
            op_name = list(operator)[0]
            param = {} if operator[op_name] is None else operator[op_name]
            # look up the operator class by name in the current module scope
            op = eval(op_name)(**param)
            ops.append(op)
        return ops

    def build_pre_process_list(self):
        resize_op = {
            "ResizeTableImage": {
                "max_len": self.table_max_len,
            }
        }
        pad_op = {
            "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
        }
        normalize_op = {
            "NormalizeImage": {
                "std": [0.229, 0.224, 0.225],
                "mean": [0.485, 0.456, 0.406],
                "scale": "1./255.",
                "order": "hwc",
            }
        }
        to_chw_op = {"ToCHWImage": None}
        keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
        self.pre_process_list = [
            resize_op,
            normalize_op,
            pad_op,
            to_chw_op,
            keep_keys_op,
        ]


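# --- Hedged usage sketch (illustrative only, not part of the original module) ---
# Feeding a dummy BGR image through the resize/normalize/pad/CHW pipeline.
# The operator classes referenced by name are defined later in this module,
# which is fine because they are only looked up when TablePreprocess() runs.
def _demo_table_preprocess():
    dummy_img = np.zeros((300, 400, 3), dtype=np.uint8)  # hypothetical input image
    img, shape = TablePreprocess()({"image": dummy_img})
    # img: (3, 488, 488) float32 CHW tensor
    # shape: [orig_h, orig_w, ratio, ratio, pad_h, pad_w]
    return img, shape

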
class ResizeTableImage:
    def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
        super(ResizeTableImage, self).__init__()
        self.max_len = max_len
        self.resize_bboxes = resize_bboxes
        self.infer_mode = infer_mode

    def __call__(self, data):
        img = data["image"]
        height, width = img.shape[0:2]
        ratio = self.max_len / (max(height, width) * 1.0)
        resize_h = int(height * ratio)
        resize_w = int(width * ratio)
        resize_img = cv2.resize(img, (resize_w, resize_h))
        if self.resize_bboxes and not self.infer_mode:
            data["bboxes"] = data["bboxes"] * ratio
        data["image"] = resize_img
        data["src_img"] = img
        data["shape"] = np.array([height, width, ratio, ratio])
        data["max_len"] = self.max_len
        return data


class PaddingTableImage:
    def __init__(self, size, **kwargs):
        super(PaddingTableImage, self).__init__()
        self.size = size

    def __call__(self, data):
        img = data["image"]
        pad_h, pad_w = self.size
        padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
        height, width = img.shape[0:2]
        padding_img[0:height, 0:width, :] = img.copy()
        data["image"] = padding_img
        shape = data["shape"].tolist()
        shape.extend([pad_h, pad_w])
        data["shape"] = np.array(shape)
        return data


class NormalizeImage:
    """normalize image, e.g. subtract mean, divide by std"""

    def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype("float32")
        self.std = np.array(std).reshape(shape).astype("float32")

    def __call__(self, data):
        img = np.array(data["image"])
        assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
        data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
        return data


class ToCHWImage:
    """convert hwc image to chw image"""

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        img = np.array(data["image"])
        data["image"] = img.transpose((2, 0, 1))
        return data


class KeepKeys:
    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        data_list = []
        for key in self.keep_keys:
            data_list.append(data[key])
        return data_list


def trans_char_ocr_res(ocr_res):
    word_result = []
    for res in ocr_res:
        score = res[2]
        for word_box, word in zip(res[3], res[4]):
            word_res = []
            word_res.append(word_box)
            word_res.append(word)
            word_res.append(score)
            word_result.append(word_res)
    return word_result
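

# --- Hedged example (illustrative only, not part of the original module) ---
# trans_char_ocr_res assumes each OCR result is laid out as
# [line_box, text, score, word_boxes, words]; it flattens a line into
# per-word [word_box, word, score] triples that share the line-level score.
# The data below is made up to show the shape.
def _demo_trans_char_ocr_res():
    ocr_res = [
        [[0, 0, 10, 10], "ab", 0.9, [[0, 0, 5, 10], [5, 0, 10, 10]], ["a", "b"]],
    ]
    return trans_char_ocr_res(ocr_res)
    # -> [[[0, 0, 5, 10], "a", 0.9], [[5, 0, 10, 10], "b", 0.9]]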