You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

173 lines
5.6 KiB
Python

import os
import sys
import subprocess
import cv2
import copy
import numpy as np
from PIL import Image
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
import tools.infer.predict_cls as predict_cls
from ppocr.utils.utility import get_image_file_list, check_and_read
from tools.infer.utility import (
get_rotate_crop_image,
get_minarea_rect_crop,
slice_generator,
merge_fragmented,
)
class TextSystem(object):
    """Text detection, optional angle classification, and recognition in one call.

    Instances are callable: pass an image and get back the boxes and
    recognition results whose score is at least ``drop_score``.
    """

    def __init__(self, args):
        """Build the predictors used by the pipeline.

        Args:
            args: Parsed options. This class reads ``use_angle_cls``,
                ``drop_score`` and ``det_box_type`` from it and forwards the
                whole object to the detector/recognizer/classifier.
        """
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        # The classifier is only constructed when angle classification is on.
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)
        self.args = args
        # Running offset so crops written by successive calls to
        # draw_crop_rec_res() do not overwrite each other.
        self.crop_image_res_index = 0

    def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
        """Write every crop in ``img_crop_list`` to ``output_dir`` as a JPEG.

        Args:
            output_dir: Destination directory; created if missing.
            img_crop_list: Images to write with ``cv2.imwrite``.
            rec_res: Unused here; kept so existing callers keep working.
        """
        os.makedirs(output_dir, exist_ok=True)
        for offset, img_crop in enumerate(img_crop_list):
            cv2.imwrite(
                os.path.join(
                    output_dir, f"mg_crop_{offset + self.crop_image_res_index}.jpg"
                ),
                img_crop,
            )
        self.crop_image_res_index += len(img_crop_list)

    def __call__(self, img, cls=True, slice=None):
        """Run the pipeline on one image.

        Args:
            img: Image to process (must support ``.copy()``), or ``None``.
            cls: When true and angle classification is enabled, crops are
                passed through the angle classifier before recognition.
            slice: Optional mapping enabling sliced detection. When truthy it
                must provide ``horizontal_stride``, ``vertical_stride``,
                ``merge_x_thres`` and ``merge_y_thres``. (The name shadows the
                builtin but is part of the public signature.)

        Returns:
            ``(boxes, rec_results, {})`` with only the entries whose score is
            ``>= drop_score``, or ``(None, None, {})`` when ``img`` is ``None``
            or nothing was detected.
        """
        if img is None:
            return None, None, {}
        ori_im = img.copy()

        if slice:
            slice_gen = slice_generator(
                img,
                horizontal_stride=slice["horizontal_stride"],
                vertical_stride=slice["vertical_stride"],
            )
            dt_slice_boxes = []
            for slice_crop, v_start, h_start in slice_gen:
                slice_boxes, _ = self.text_detector(slice_crop, use_slice=True)
                if slice_boxes is not None and slice_boxes.size:
                    # Shift slice-local coordinates back into the full image.
                    slice_boxes[:, :, 0] += h_start
                    slice_boxes[:, :, 1] += v_start
                    dt_slice_boxes.append(slice_boxes)
            if dt_slice_boxes:
                dt_boxes = merge_fragmented(
                    boxes=np.concatenate(dt_slice_boxes),
                    x_threshold=slice["merge_x_thres"],
                    y_threshold=slice["merge_y_thres"],
                )
            else:
                # np.concatenate([]) raises ValueError; treat "no slice had
                # any box" the same as "detector found nothing".
                dt_boxes = None
        else:
            dt_boxes, _ = self.text_detector(img)

        if dt_boxes is None:
            return None, None, {}

        dt_boxes = sorted_boxes(dt_boxes)
        use_quad_crop = self.args.det_box_type == "quad"
        img_crop_list = []
        for box in dt_boxes:
            # Copy so the crop helpers cannot modify the box we return.
            tmp_box = copy.deepcopy(box)
            if use_quad_crop:
                img_crop = get_rotate_crop_image(ori_im, tmp_box)
            else:
                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
            img_crop_list.append(img_crop)

        if self.use_angle_cls and cls:
            img_crop_list, _, _ = self.text_classifier(img_crop_list)

        rec_res, _ = self.text_recognizer(img_crop_list)

        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            # rec_result is indexed as (text, score).
            if rec_result[1] >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res, {}
def sorted_boxes(dt_boxes):
    """Return the boxes of ``dt_boxes`` as a list in reading order.

    Boxes are first sorted by their first point, keyed on index 1 then
    index 0 (presumably y then x). Afterwards, neighbours whose first points
    differ by less than 10 in index 1 are reordered so the one with the
    smaller index-0 value comes first.

    Args:
        dt_boxes: Array-like with a ``shape`` attribute; iterating it yields
            boxes whose element ``[0]`` is the point used for ordering.

    Returns:
        A list containing the same boxes in the new order.
    """
    box_count = dt_boxes.shape[0]
    ordered = sorted(dt_boxes, key=lambda quad: (quad[0][1], quad[0][0]))
    for start in range(box_count - 1):
        pos = start
        # Bubble the box at pos + 1 backwards while it sits on the same
        # "row" as its predecessor but further to the left.
        while pos >= 0:
            prev_box, next_box = ordered[pos], ordered[pos + 1]
            same_row = abs(next_box[0][1] - prev_box[0][1]) < 10
            if not (same_row and next_box[0][0] < prev_box[0][0]):
                break
            ordered[pos], ordered[pos + 1] = next_box, prev_box
            pos -= 1
    return ordered
def main(args):
    """Run the OCR pipeline on this process's share of images and print the text.

    Args:
        args: Parsed options. Reads ``image_dir``, ``process_id``,
            ``total_process_num``, ``warmup`` and ``page_num``; the whole
            object is also passed to ``TextSystem``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    # Stride through the list so each worker handles a disjoint subset.
    image_file_list = image_file_list[args.process_id :: args.total_process_num]
    text_sys = TextSystem(args)

    # Warm-up (optional): run a few passes on random data and discard results.
    if args.warmup:
        warmup_img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for _ in range(10):
            text_sys(warmup_img)

    for image_file in image_file_list:
        img, flag_gif, flag_pdf = check_and_read(image_file)
        # Fall back to OpenCV when check_and_read flagged neither GIF nor PDF.
        if not flag_gif and not flag_pdf:
            img = cv2.imread(image_file)
        if not flag_pdf:
            if img is None:
                # Unreadable file: skip it.
                continue
            imgs = [img]
        else:
            # In the PDF case img is used as a sequence (presumably one entry
            # per page); page_num == 0 or too large means "all of them".
            page_num = args.page_num
            if page_num > len(img) or page_num == 0:
                page_num = len(img)
            imgs = img[:page_num]

        for page in imgs:
            _, rec_res, _ = text_sys(page)
            # text_sys returns None results when nothing was detected;
            # iterating None would raise TypeError.
            if rec_res is None:
                continue
            # Output the recognized text
            for text, _score in rec_res:
                print(f"{text}")
if __name__ == "__main__":
    # Entry point: either process images in this interpreter, or fan out to
    # args.total_process_num child interpreters, each re-running this script
    # with its own --process_id and with multiprocessing switched off.
    args = utility.parse_args()
    if not args.use_mp:
        main(args)
    else:
        workers = [
            subprocess.Popen(
                [
                    sys.executable,
                    "-u",
                    *sys.argv,
                    "--process_id={}".format(rank),
                    "--use_mp={}".format(False),
                ],
                stdout=sys.stdout,
                # Child stderr is deliberately routed to this stdout too.
                stderr=sys.stdout,
            )
            for rank in range(args.total_process_num)
        ]
        # All children are started before any of them is waited on.
        for worker in workers:
            worker.wait()