# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../.."))) os.environ["FLAGS_allocator_strategy"] = "auto_growth" import cv2 import numpy as np import time import sys import tools.infer.utility as utility from ppocr.utils.logging import get_logger from ppocr.utils.utility import get_image_file_list, check_and_read from ppocr.data import create_operators, transform from ppocr.postprocess import build_post_process logger = get_logger() class TextE2E(object): def __init__(self, args): self.args = args self.e2e_algorithm = args.e2e_algorithm self.use_onnx = args.use_onnx pre_process_list = [ {"E2EResizeForTest": {}}, { "NormalizeImage": { "std": [0.229, 0.224, 0.225], "mean": [0.485, 0.456, 0.406], "scale": "1./255.", "order": "hwc", } }, {"ToCHWImage": None}, {"KeepKeys": {"keep_keys": ["image", "shape"]}}, ] postprocess_params = {} if self.e2e_algorithm == "PGNet": pre_process_list[0] = { "E2EResizeForTest": { "max_side_len": args.e2e_limit_side_len, "valid_set": "totaltext", } } postprocess_params["name"] = "PGPostProcess" postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh postprocess_params["character_dict_path"] = args.e2e_char_dict_path postprocess_params["valid_set"] = args.e2e_pgnet_valid_set postprocess_params["mode"] = args.e2e_pgnet_mode else: logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm)) sys.exit(0) self.preprocess_op = create_operators(pre_process_list) self.postprocess_op = build_post_process(postprocess_params) ( self.predictor, self.input_tensor, self.output_tensors, _, ) = utility.create_predictor( args, "e2e", logger ) # paddle.jit.load(args.det_model_dir) # self.predictor.eval() def clip_det_res(self, points, img_height, img_width): for pno in range(points.shape[0]): points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) return points def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): img_height, img_width = image_shape[0:2] dt_boxes_new = [] for box in dt_boxes: box = self.clip_det_res(box, img_height, img_width) dt_boxes_new.append(box) dt_boxes = np.array(dt_boxes_new) return dt_boxes def __call__(self, img): ori_im = img.copy() data = {"image": img} data = transform(data, self.preprocess_op) img, shape_list = data if img is None: return None, 0 img = np.expand_dims(img, axis=0) shape_list = np.expand_dims(shape_list, axis=0) img = img.copy() starttime = time.time() if self.use_onnx: input_dict = {} input_dict[self.input_tensor.name] = img outputs = self.predictor.run(self.output_tensors, input_dict) preds = {} preds["f_border"] = outputs[0] preds["f_char"] = outputs[1] preds["f_direction"] = outputs[2] preds["f_score"] = outputs[3] else: self.input_tensor.copy_from_cpu(img) self.predictor.run() outputs = [] for output_tensor in self.output_tensors: output = output_tensor.copy_to_cpu() outputs.append(output) preds = {} if self.e2e_algorithm == "PGNet": preds["f_border"] = outputs[0] preds["f_char"] = outputs[1] preds["f_direction"] = outputs[2] preds["f_score"] = outputs[3] else: raise NotImplementedError post_result = self.postprocess_op(preds, shape_list) points, strs = post_result["points"], post_result["texts"] dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape) elapse = time.time() - starttime return dt_boxes, strs, elapse if __name__ == "__main__": args = utility.parse_args() image_file_list = get_image_file_list(args.image_dir) text_detector = TextE2E(args) count = 0 total_time = 0 draw_img_save = "./inference_results" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) for image_file in image_file_list: img, flag, _ = check_and_read(image_file) if not flag: img = cv2.imread(image_file) if img is None: logger.info("error in loading image:{}".format(image_file)) continue points, strs, elapse = text_detector(img) if count > 0: total_time += elapse count += 1 logger.info("Predict time of {}: {}".format(image_file, elapse)) src_im = utility.draw_e2e_res(points, strs, image_file) img_name_pure = os.path.split(image_file)[-1] img_path = os.path.join(draw_img_save, "e2e_res_{}".format(img_name_pure)) cv2.imwrite(img_path, src_im) logger.info("The visualized image saved in {}".format(img_path)) if count > 1: logger.info("Avg Time: {}".format(total_time / (count - 1)))