You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

231 lines
8.3 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import ast
import argparse
from typing import List, Tuple
import numpy as np
class LayoutBox:
    """One detected layout region.

    The constructor arguments are stored unchanged as attributes:

    * ``clsid`` -- integer class id of the region.
    * ``pos`` -- box coordinates; the helpers in this file index it as
      ``[x1, y1, x2, y2]``.
    * ``confidence`` -- score attached to the region.
    """

    def __init__(self, clsid: int, pos: List[float], confidence: float):
        self.clsid, self.pos, self.confidence = clsid, pos, confidence
class PageDetectionResult:
    """Bundle of the layout boxes found on a page and the page image.

    Both constructor arguments are stored as given, without copying:
    ``boxes`` (a list of ``LayoutBox``) and ``image`` (a NumPy array).
    """

    def __init__(self, boxes: List[LayoutBox], image: np.ndarray):
        self.image = image
        self.boxes = boxes
class Times:
    """Simple stopwatch that accumulates elapsed wall-clock seconds.

    ``st`` and ``et`` hold the ``time.time()`` stamps taken by ``start`` and
    ``end``; ``time`` holds the recorded duration in seconds.
    """

    def __init__(self):
        # Recorded duration, start stamp and end stamp all begin at zero.
        self.time = self.st = self.et = 0.

    def start(self):
        """Record the current wall-clock time as the start stamp."""
        self.st = time.time()

    def end(self, repeats=1, accumulative=True):
        """Record the end stamp and update the stored duration.

        The elapsed time since ``start`` is divided by ``repeats``. It is
        added to the stored duration when ``accumulative`` is true, and
        replaces it otherwise.
        """
        self.et = time.time()
        elapsed = (self.et - self.st) / repeats
        self.time = self.time + elapsed if accumulative else elapsed

    def reset(self):
        """Zero the stored duration and both time stamps."""
        self.time = self.st = self.et = 0.

    def value(self):
        """Return the stored duration rounded to four decimal places."""
        return round(self.time, 4)
class Timer(Times):
    """Per-stage timing summary for an inference pipeline.

    Holds one ``Times`` stopwatch each for preprocessing, inference,
    postprocessing and tracking, plus an image counter ``img_num``. The
    counter starts at 0 and is never modified by this class; callers are
    expected to update it. Tracking time is only included in totals and
    output when ``with_tracker`` is true.
    """

    def __init__(self, with_tracker=False):
        super(Timer, self).__init__()
        self.with_tracker = with_tracker
        self.preprocess_time_s = Times()
        self.inference_time_s = Times()
        self.postprocess_time_s = Times()
        self.tracking_time_s = Times()
        self.img_num = 0

    def _per_image(self, seconds, average):
        """Return ``seconds`` as is, or divided by the image count if ``average``."""
        if not average:
            return seconds
        # max(1, ...) avoids dividing by zero before any image is counted.
        return round(seconds / max(1, self.img_num), 4)

    def _total_seconds(self):
        """Return the rounded sum of all stage times that are in use."""
        total = (self.preprocess_time_s.value() +
                 self.inference_time_s.value() +
                 self.postprocess_time_s.value())
        if self.with_tracker:
            total = total + self.tracking_time_s.value()
        return round(total, 4)

    def info(self, average=False):
        """Print a human-readable timing summary in milliseconds.

        Args:
            average: when true, the per-stage figures are divided by
                ``img_num``; the total and the average latency / QPS lines
                are printed the same way regardless of this flag.
        """
        total_time = self._total_seconds()
        print("------------------ Inference Time Info ----------------------")
        print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
                                                       self.img_num))

        preprocess_time = self._per_image(self.preprocess_time_s.value(),
                                          average)
        postprocess_time = self._per_image(self.postprocess_time_s.value(),
                                           average)
        inference_time = self._per_image(self.inference_time_s.value(),
                                         average)
        tracking_time = self._per_image(self.tracking_time_s.value(), average)

        average_latency = total_time / max(1, self.img_num)
        # QPS is left at 0 when nothing has been timed, to avoid 1 / 0.
        qps = 1 / average_latency if total_time > 0 else 0
        # Fixed: the QPS field used "{:2f}" (width 2, six decimals) where
        # two decimals, like every other field, was clearly intended.
        print("average latency time(ms): {:.2f}, QPS: {:.2f}".format(
            average_latency * 1000, qps))

        if self.with_tracker:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000, tracking_time * 1000))
        else:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000))

    def report(self, average=False):
        """Return the timing summary as a dict, with times in seconds.

        Args:
            average: when true, per-stage times are divided by ``img_num``.

        Returns:
            A dict with keys ``preprocess_time_s``, ``inference_time_s``,
            ``postprocess_time_s``, ``img_num``, ``total_time_s`` and, when
            ``with_tracker`` is true, ``tracking_time_s``. ``total_time_s``
            is never averaged.
        """
        dic = {}
        dic['preprocess_time_s'] = self._per_image(
            self.preprocess_time_s.value(), average)
        dic['inference_time_s'] = self._per_image(
            self.inference_time_s.value(), average)
        dic['postprocess_time_s'] = self._per_image(
            self.postprocess_time_s.value(), average)
        dic['img_num'] = self.img_num
        if self.with_tracker:
            dic['tracking_time_s'] = self._per_image(
                self.tracking_time_s.value(), average)
        dic['total_time_s'] = self._total_seconds()
        return dic
def iou(box1, box2):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Returns 0 when the union area is zero.
    """
    # Corners of the intersection rectangle.
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # Clamp at 0 so disjoint boxes contribute no intersection area.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    intersection = overlap_w * overlap_h

    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_a + area_b - intersection

    if union == 0:
        return 0
    return intersection / union
def non_max_suppression(boxes, scores, iou_threshold):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other remaining box whose IoU with it is greater than ``iou_threshold``.

    Args:
        boxes: sequence (list or NumPy array) of ``[x1, y1, x2, y2]`` boxes.
        scores: one score per box, in the same order as ``boxes``.
        iou_threshold: boxes with IoU above this value are suppressed.

    Returns:
        A list of indices into ``boxes`` for the kept boxes, ordered from
        highest to lowest score. The indices are the integer values
        produced by ``np.argsort``. Empty input gives an empty list.
    """
    # Use an explicit None/len check rather than truthiness: ``not boxes``
    # raises ValueError when ``boxes`` is a multi-element NumPy array.
    if boxes is None or len(boxes) == 0:
        return []

    # Indices sorted by descending score.
    order = np.argsort(scores)[::-1]
    keep = []
    while len(order) > 0:
        best = order[0]
        keep.append(best)
        # Only boxes that do not overlap the kept box too much survive.
        order = [
            idx for idx in order[1:]
            if iou(boxes[best], boxes[idx]) <= iou_threshold
        ]
    return keep
def is_box_inside(inner, outer):
    """Return whether ``inner`` lies completely within ``outer``.

    Both boxes are ``[x1, y1, x2, y2]``; shared edges count as inside.
    """
    return (inner[0] >= outer[0] and inner[1] >= outer[1]
            and inner[2] <= outer[2] and inner[3] <= outer[3])
def boxes_overlap(box1: List[float], box2: List[float]) -> bool:
    """Return whether two ``[x1, y1, x2, y2]`` boxes share a positive area.

    Boxes that merely touch along an edge or at a corner do not overlap.
    """
    # Corners of the would-be intersection rectangle.
    overlap_left, overlap_top = max(box1[0], box2[0]), max(box1[1], box2[1])
    overlap_right, overlap_bottom = min(box1[2], box2[2]), min(box1[3], box2[3])
    return overlap_right > overlap_left and overlap_bottom > overlap_top
def merge_boxes(boxes: List[List[float]]) -> List[float]:
    """Return the smallest ``[x1, y1, x2, y2]`` box enclosing all ``boxes``."""

    def column(k):
        # The k-th coordinate of every box.
        return [box[k] for box in boxes]

    return [min(column(0)), min(column(1)), max(column(2)), max(column(3))]
def merge_text_and_title_boxes(data: List[LayoutBox], merged_labels: Tuple[int, ...], other_labels: Tuple[int, ...], merged_box_label: int) -> List[LayoutBox]:
    """Greedily merge boxes of selected classes into larger boxes.

    Boxes whose ``clsid`` is in ``merged_labels`` are sorted by their top
    edge (``pos[1]``) and grouped in that order. A box joins the current
    group as long as the bounding box of the group plus that box does not
    overlap any box whose ``clsid`` is in ``other_labels``. The first box
    that would cause such an overlap starts a new group.

    Args:
        data: all detected boxes.
        merged_labels: class ids of the boxes that may be merged.
        other_labels: class ids of the boxes a merged region must not
            overlap.
        merged_box_label: class id given to every box created by merging
            two or more boxes.

    Returns:
        First the mergeable boxes in top-to-bottom order. Groups of two or
        more become a new ``LayoutBox`` with the enclosing position and the
        highest confidence of the group; a lone box is returned as the
        original object. These are followed by every box of ``data`` whose
        class is not in ``merged_labels``, in original order.
    """
    # sorted() is stable, so boxes with equal top edges keep their data order.
    candidates = sorted(
        (box for box in data if box.clsid in merged_labels),
        key=lambda box: box.pos[1],
    )
    blockers = [box.pos for box in data if box.clsid in other_labels]

    merged = []
    start = 0
    while start < len(candidates):
        first = candidates[start]
        extent = first.pos
        best_confidence = first.confidence

        # Extend the group with consecutive candidates until one would make
        # the merged region overlap a blocking box.
        end = start + 1
        while end < len(candidates):
            candidate = candidates[end]
            tentative = merge_boxes([extent, candidate.pos])
            if any(boxes_overlap(blocker, tentative) for blocker in blockers):
                break
            extent = tentative
            best_confidence = max(best_confidence, candidate.confidence)
            end += 1

        if end - start > 1:
            merged.append(LayoutBox(merged_box_label, extent, best_confidence))
        else:
            merged.append(first)
        start = end

    # Fixed: the original filtered ``data`` indices against a set of
    # positions in the sorted candidate list, which silently dropped
    # unrelated boxes. Every non-mergeable box is kept.
    remaining = [box for box in data if box.clsid not in merged_labels]
    return merged + remaining