# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com import os import platform import traceback from enum import Enum from pathlib import Path from typing import Any, Dict, List, Tuple, Union import cv2 import numpy as np from onnxruntime import ( GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers, get_device, ) from rapid_table.utils import Logger class EP(Enum): CPU_EP = "CPUExecutionProvider" CUDA_EP = "CUDAExecutionProvider" DIRECTML_EP = "DmlExecutionProvider" class OrtInferSession: def __init__(self, config: Dict[str, Any]): self.logger = Logger(logger_name=__name__).get_log() model_path = config.get("model_path", None) self._verify_model(model_path) self.cfg_use_cuda = config.get("use_cuda", None) self.cfg_use_dml = config.get("use_dml", None) self.had_providers: List[str] = get_available_providers() EP_list = self._get_ep_list() sess_opt = self._init_sess_opts(config) self.session = InferenceSession( model_path, sess_options=sess_opt, providers=EP_list, ) self._verify_providers() @staticmethod def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions: sess_opt = SessionOptions() sess_opt.log_severity_level = 4 sess_opt.enable_cpu_mem_arena = False sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL cpu_nums = os.cpu_count() intra_op_num_threads = config.get("intra_op_num_threads", -1) if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: sess_opt.intra_op_num_threads = intra_op_num_threads inter_op_num_threads = config.get("inter_op_num_threads", -1) if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums: sess_opt.inter_op_num_threads = inter_op_num_threads return sess_opt def get_metadata(self, key: str = "character") -> list: meta_dict = self.session.get_modelmeta().custom_metadata_map content_list = meta_dict[key].splitlines() return content_list def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: cpu_provider_opts = { "arena_extend_strategy": "kSameAsRequested", } EP_list = [(EP.CPU_EP.value, cpu_provider_opts)] cuda_provider_opts = { "device_id": 0, "arena_extend_strategy": "kNextPowerOfTwo", "cudnn_conv_algo_search": "EXHAUSTIVE", "do_copy_in_default_stream": True, } self.use_cuda = self._check_cuda() if self.use_cuda: EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts)) self.use_directml = self._check_dml() if self.use_directml: self.logger.info( "Windows 10 or above detected, try to use DirectML as primary provider" ) directml_options = ( cuda_provider_opts if self.use_cuda else cpu_provider_opts ) EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options)) return EP_list def _check_cuda(self) -> bool: if not self.cfg_use_cuda: return False cur_device = get_device() if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers: return True self.logger.warning( "%s is not in available providers (%s). Use %s inference by default.", EP.CUDA_EP.value, self.had_providers, self.had_providers[0], ) self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.") self.logger.info( "(For reference only) If you want to use GPU acceleration, you must do:" ) self.logger.info( "First, uninstall all onnxruntime pakcages in current environment." ) self.logger.info( "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`." ) self.logger.info( "\tNote the onnxruntime-gpu version must match your cuda and cudnn version." ) self.logger.info( "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html" ) self.logger.info( "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']", EP.CUDA_EP.value, ) return False def _check_dml(self) -> bool: if not self.cfg_use_dml: return False cur_os = platform.system() if cur_os != "Windows": self.logger.warning( "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.", cur_os, self.had_providers[0], ) return False cur_window_version = int(platform.release().split(".")[0]) if cur_window_version < 10: self.logger.warning( "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.", cur_window_version, self.had_providers[0], ) return False if EP.DIRECTML_EP.value in self.had_providers: return True self.logger.warning( "%s is not in available providers (%s). Use %s inference by default.", EP.DIRECTML_EP.value, self.had_providers, self.had_providers[0], ) self.logger.info("If you want to use DirectML acceleration, you must do:") self.logger.info( "First, uninstall all onnxruntime pakcages in current environment." ) self.logger.info( "Second, install onnxruntime-directml by `pip install onnxruntime-directml`" ) self.logger.info( "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']", EP.DIRECTML_EP.value, ) return False def _verify_providers(self): session_providers = self.session.get_providers() first_provider = session_providers[0] if self.use_cuda and first_provider != EP.CUDA_EP.value: self.logger.warning( "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.", EP.CUDA_EP.value, first_provider, ) if self.use_directml and first_provider != EP.DIRECTML_EP.value: self.logger.warning( "%s is not available for current env, the inference part is automatically shifted to be executed under %s.", EP.DIRECTML_EP.value, first_provider, ) def __call__(self, input_content: List[np.ndarray]) -> np.ndarray: input_dict = dict(zip(self.get_input_names(), input_content)) try: return self.session.run(None, input_dict) except Exception as e: error_info = traceback.format_exc() raise ONNXRuntimeError(error_info) from e def get_input_names(self) -> List[str]: return [v.name for v in self.session.get_inputs()] def get_output_names(self) -> List[str]: return [v.name for v in self.session.get_outputs()] def get_character_list(self, key: str = "character") -> List[str]: meta_dict = self.session.get_modelmeta().custom_metadata_map return meta_dict[key].splitlines() def have_key(self, key: str = "character") -> bool: meta_dict = self.session.get_modelmeta().custom_metadata_map if key in meta_dict.keys(): return True return False @staticmethod def _verify_model(model_path: Union[str, Path, None]): if model_path is None: raise ValueError("model_path is None!") model_path = Path(model_path) if not model_path.exists(): raise FileNotFoundError(f"{model_path} does not exists.") if not model_path.is_file(): raise FileExistsError(f"{model_path} is not a file.") class ONNXRuntimeError(Exception): pass class TableLabelDecode: def __init__(self, dict_character, merge_no_span_structure=True, **kwargs): if merge_no_span_structure: if "" not in dict_character: dict_character.append("") if "" in dict_character: dict_character.remove("") dict_character = self.add_special_char(dict_character) self.dict = {} for i, char in enumerate(dict_character): self.dict[char] = i self.character = dict_character self.td_token = ["", ""] def __call__(self, preds, batch=None): structure_probs = preds["structure_probs"] bbox_preds = preds["loc_preds"] shape_list = batch[-1] result = self.decode(structure_probs, bbox_preds, shape_list) if len(batch) == 1: # only contains shape return result label_decode_result = self.decode_label(batch) return result, label_decode_result def decode(self, structure_probs, bbox_preds, shape_list): """convert text-label into text-index.""" ignored_tokens = self.get_ignored_tokens() end_idx = self.dict[self.end_str] structure_idx = structure_probs.argmax(axis=2) structure_probs = structure_probs.max(axis=2) structure_batch_list = [] bbox_batch_list = [] batch_size = len(structure_idx) for batch_idx in range(batch_size): structure_list = [] bbox_list = [] score_list = [] for idx in range(len(structure_idx[batch_idx])): char_idx = int(structure_idx[batch_idx][idx]) if idx > 0 and char_idx == end_idx: break if char_idx in ignored_tokens: continue text = self.character[char_idx] if text in self.td_token: bbox = bbox_preds[batch_idx, idx] bbox = self._bbox_decode(bbox, shape_list[batch_idx]) bbox_list.append(bbox) structure_list.append(text) score_list.append(structure_probs[batch_idx, idx]) structure_batch_list.append([structure_list, np.mean(score_list)]) bbox_batch_list.append(np.array(bbox_list)) result = { "bbox_batch_list": bbox_batch_list, "structure_batch_list": structure_batch_list, } return result def decode_label(self, batch): """convert text-label into text-index.""" structure_idx = batch[1] gt_bbox_list = batch[2] shape_list = batch[-1] ignored_tokens = self.get_ignored_tokens() end_idx = self.dict[self.end_str] structure_batch_list = [] bbox_batch_list = [] batch_size = len(structure_idx) for batch_idx in range(batch_size): structure_list = [] bbox_list = [] for idx in range(len(structure_idx[batch_idx])): char_idx = int(structure_idx[batch_idx][idx]) if idx > 0 and char_idx == end_idx: break if char_idx in ignored_tokens: continue structure_list.append(self.character[char_idx]) bbox = gt_bbox_list[batch_idx][idx] if bbox.sum() != 0: bbox = self._bbox_decode(bbox, shape_list[batch_idx]) bbox_list.append(bbox) structure_batch_list.append(structure_list) bbox_batch_list.append(bbox_list) result = { "bbox_batch_list": bbox_batch_list, "structure_batch_list": structure_batch_list, } return result def _bbox_decode(self, bbox, shape): h, w = shape[:2] bbox[0::2] *= w bbox[1::2] *= h return bbox def get_ignored_tokens(self): beg_idx = self.get_beg_end_flag_idx("beg") end_idx = self.get_beg_end_flag_idx("end") return [beg_idx, end_idx] def get_beg_end_flag_idx(self, beg_or_end): if beg_or_end == "beg": return np.array(self.dict[self.beg_str]) if beg_or_end == "end": return np.array(self.dict[self.end_str]) raise TypeError(f"unsupport type {beg_or_end} in get_beg_end_flag_idx") def add_special_char(self, dict_character): self.beg_str = "sos" self.end_str = "eos" dict_character = [self.beg_str] + dict_character + [self.end_str] return dict_character class TablePreprocess: def __init__(self): self.table_max_len = 488 self.build_pre_process_list() self.ops = self.create_operators() def __call__(self, data): """transform""" if self.ops is None: self.ops = [] for op in self.ops: data = op(data) if data is None: return None return data def create_operators( self, ): """ create operators based on the config Args: params(list): a dict list, used to create some operators """ assert isinstance( self.pre_process_list, list ), "operator config should be a list" ops = [] for operator in self.pre_process_list: assert ( isinstance(operator, dict) and len(operator) == 1 ), "yaml format error" op_name = list(operator)[0] param = {} if operator[op_name] is None else operator[op_name] op = eval(op_name)(**param) ops.append(op) return ops def build_pre_process_list(self): resize_op = { "ResizeTableImage": { "max_len": self.table_max_len, } } pad_op = { "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]} } normalize_op = { "NormalizeImage": { "std": [0.229, 0.224, 0.225], "mean": [0.485, 0.456, 0.406], "scale": "1./255.", "order": "hwc", } } to_chw_op = {"ToCHWImage": None} keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}} self.pre_process_list = [ resize_op, normalize_op, pad_op, to_chw_op, keep_keys_op, ] class ResizeTableImage: def __init__(self, max_len, resize_bboxes=False, infer_mode=False): super(ResizeTableImage, self).__init__() self.max_len = max_len self.resize_bboxes = resize_bboxes self.infer_mode = infer_mode def __call__(self, data): img = data["image"] height, width = img.shape[0:2] ratio = self.max_len / (max(height, width) * 1.0) resize_h = int(height * ratio) resize_w = int(width * ratio) resize_img = cv2.resize(img, (resize_w, resize_h)) if self.resize_bboxes and not self.infer_mode: data["bboxes"] = data["bboxes"] * ratio data["image"] = resize_img data["src_img"] = img data["shape"] = np.array([height, width, ratio, ratio]) data["max_len"] = self.max_len return data class PaddingTableImage: def __init__(self, size, **kwargs): super(PaddingTableImage, self).__init__() self.size = size def __call__(self, data): img = data["image"] pad_h, pad_w = self.size padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32) height, width = img.shape[0:2] padding_img[0:height, 0:width, :] = img.copy() data["image"] = padding_img shape = data["shape"].tolist() shape.extend([pad_h, pad_w]) data["shape"] = np.array(shape) return data class NormalizeImage: """normalize image such as substract mean, divide std""" def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs): if isinstance(scale, str): scale = eval(scale) self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) mean = mean if mean is not None else [0.485, 0.456, 0.406] std = std if std is not None else [0.229, 0.224, 0.225] shape = (3, 1, 1) if order == "chw" else (1, 1, 3) self.mean = np.array(mean).reshape(shape).astype("float32") self.std = np.array(std).reshape(shape).astype("float32") def __call__(self, data): img = np.array(data["image"]) assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std return data class ToCHWImage: """convert hwc image to chw image""" def __init__(self, **kwargs): pass def __call__(self, data): img = np.array(data["image"]) data["image"] = img.transpose((2, 0, 1)) return data class KeepKeys: def __init__(self, keep_keys, **kwargs): self.keep_keys = keep_keys def __call__(self, data): data_list = [] for key in self.keep_keys: data_list.append(data[key]) return data_list def trans_char_ocr_res(ocr_res): word_result = [] for res in ocr_res: score = res[2] for word_box, word in zip(res[3], res[4]): word_res = [] word_res.append(word_box) word_res.append(word) word_res.append(score) word_result.append(word_res) return word_result