# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os
import platform
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import cv2
import numpy as np
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    get_available_providers,
    get_device,
)

from rapid_table.utils import Logger


class EP(Enum):
    CPU_EP = "CPUExecutionProvider"
    CUDA_EP = "CUDAExecutionProvider"
    DIRECTML_EP = "DmlExecutionProvider"


class OrtInferSession:
    def __init__(self, config: Dict[str, Any]):
        self.logger = Logger(logger_name=__name__).get_log()

        model_path = config.get("model_path", None)
        self._verify_model(model_path)

        self.cfg_use_cuda = config.get("use_cuda", None)
        self.cfg_use_dml = config.get("use_dml", None)

        self.had_providers: List[str] = get_available_providers()
        EP_list = self._get_ep_list()

        sess_opt = self._init_sess_opts(config)
        self.session = InferenceSession(
            model_path,
            sess_options=sess_opt,
            providers=EP_list,
        )
        self._verify_providers()

    @staticmethod
    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
        sess_opt = SessionOptions()
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cpu_nums = os.cpu_count()
        intra_op_num_threads = config.get("intra_op_num_threads", -1)
        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
            sess_opt.intra_op_num_threads = intra_op_num_threads

        inter_op_num_threads = config.get("inter_op_num_threads", -1)
        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
            sess_opt.inter_op_num_threads = inter_op_num_threads

        return sess_opt

    def get_metadata(self, key: str = "character") -> list:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        content_list = meta_dict[key].splitlines()
        return content_list

    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
        cpu_provider_opts = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

        cuda_provider_opts = {
            "device_id": 0,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": True,
        }
        self.use_cuda = self._check_cuda()
        if self.use_cuda:
            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))

        self.use_directml = self._check_dml()
        if self.use_directml:
            self.logger.info(
                "Windows 10 or above detected, try to use DirectML as primary provider"
            )
            directml_options = (
                cuda_provider_opts if self.use_cuda else cpu_provider_opts
            )
            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
        return EP_list
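    # Illustration (comment only, not executed): with use_cuda=True on a machine
    # where the CUDA EP is installed, _get_ep_list above returns, in priority order:
    #   [("CUDAExecutionProvider", cuda_provider_opts),
    #    ("CPUExecutionProvider", cpu_provider_opts)]
    # ONNX Runtime tries providers left to right, so CPUExecutionProvider always
    # remains as the final fallback.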
    def _check_cuda(self) -> bool:
        if not self.cfg_use_cuda:
            return False

        cur_device = get_device()
        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.CUDA_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
        self.logger.info(
            "(For reference only) If you want to use GPU acceleration, you must do:"
        )
        self.logger.info(
            "First, uninstall all onnxruntime packages in current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
        )
        self.logger.info(
            "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
        )
        self.logger.info(
            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
        )
        self.logger.info(
            "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
            EP.CUDA_EP.value,
        )
        return False

    def _check_dml(self) -> bool:
        if not self.cfg_use_dml:
            return False

        cur_os = platform.system()
        if cur_os != "Windows":
            self.logger.warning(
                "DirectML is only supported on Windows. The current OS is %s. Use %s inference by default.",
                cur_os,
                self.had_providers[0],
            )
            return False

        cur_window_version = int(platform.release().split(".")[0])
        if cur_window_version < 10:
            self.logger.warning(
                "DirectML is only supported on Windows 10 and above. The current Windows version is %s. Use %s inference by default.",
                cur_window_version,
                self.had_providers[0],
            )
            return False

        if EP.DIRECTML_EP.value in self.had_providers:
            return True

        self.logger.warning(
            "%s is not in available providers (%s). Use %s inference by default.",
            EP.DIRECTML_EP.value,
            self.had_providers,
            self.had_providers[0],
        )
        self.logger.info("If you want to use DirectML acceleration, you must do:")
        self.logger.info(
            "First, uninstall all onnxruntime packages in current environment."
        )
        self.logger.info(
            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
        )
        self.logger.info(
            "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
            EP.DIRECTML_EP.value,
        )
        return False
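    # Config keys consumed by OrtInferSession (summarized from __init__ and
    # _init_sess_opts above; not an exhaustive schema):
    #   model_path            -- required path to the .onnx model file
    #   use_cuda / use_dml    -- opt in to the CUDA / DirectML providers
    #   intra_op_num_threads  -- thread count; -1 keeps the ONNX Runtime default
    #   inter_op_num_threads  -- thread count; -1 keeps the ONNX Runtime default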
    def _verify_providers(self):
        session_providers = self.session.get_providers()
        first_provider = session_providers[0]

        if self.use_cuda and first_provider != EP.CUDA_EP.value:
            self.logger.warning(
                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
                EP.CUDA_EP.value,
                first_provider,
            )

        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
            self.logger.warning(
                "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
                EP.DIRECTML_EP.value,
                first_provider,
            )

    def __call__(self, input_content: List[np.ndarray]) -> List[np.ndarray]:
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(None, input_dict)
        except Exception as e:
            error_info = traceback.format_exc()
            raise ONNXRuntimeError(error_info) from e

    def get_input_names(self) -> List[str]:
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self) -> List[str]:
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character") -> List[str]:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        meta_dict = self.session.get_modelmeta().custom_metadata_map
        return key in meta_dict

    @staticmethod
    def _verify_model(model_path: Union[str, Path, None]):
        if model_path is None:
            raise ValueError("model_path is None!")

        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")

        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")


class ONNXRuntimeError(Exception):
    pass


class TableLabelDecode:
    def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
        if merge_no_span_structure:
            # Merge "<td>" into the single "<td></td>" structure token.
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")
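

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library. The model path is a
    # placeholder and the input shape is illustrative; both must match a real
    # table-recognition ONNX model before this will run.
    config = {
        "model_path": "models/table_structure.onnx",  # hypothetical path
        "use_cuda": False,
        "intra_op_num_threads": -1,
    }
    session = OrtInferSession(config)
    print("inputs:", session.get_input_names())
    print("outputs:", session.get_output_names())

    dummy = np.zeros((1, 3, 488, 488), dtype=np.float32)  # illustrative shape
    outputs = session([dummy])
    print([o.shape for o in outputs])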