# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import os
import ast
import argparse
from typing import List, Tuple

import numpy as np


class LayoutBox(object):
    """A single detected layout element: class id, [x1, y1, x2, y2] position and confidence."""

    def __init__(self, clsid: int, pos: List[float], confidence: float):
        self.clsid = clsid
        self.pos = pos
        self.confidence = confidence


class PageDetectionResult(object):
    """Layout detection result for one page: the detected boxes and the page image."""

    def __init__(self, boxes: List[LayoutBox], image: np.ndarray):
        self.boxes = boxes
        self.image = image

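
# A minimal, illustrative construction of the two containers above. The class
# ids, coordinates and blank image are made-up values, assuming the usual
# [x1, y1, x2, y2] pixel convention used by the helpers later in this file.
def _example_page_detection_result() -> PageDetectionResult:
    boxes = [
        LayoutBox(clsid=0, pos=[60.0, 80.0, 520.0, 140.0], confidence=0.93),
        LayoutBox(clsid=1, pos=[60.0, 160.0, 520.0, 400.0], confidence=0.88),
    ]
    image = np.zeros((800, 600, 3), dtype=np.uint8)  # dummy blank page
    return PageDetectionResult(boxes, image)
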
class Times(object):
    def __init__(self):
        self.time = 0.
        # start time
        self.st = 0.
        # end time
        self.et = 0.

    def start(self):
        self.st = time.time()

    def end(self, repeats=1, accumulative=True):
        self.et = time.time()
        if accumulative:
            self.time += (self.et - self.st) / repeats
        else:
            self.time = (self.et - self.st) / repeats

    def reset(self):
        self.time = 0.
        self.st = 0.
        self.et = 0.

    def value(self):
        return round(self.time, 4)


class Timer(Times):
    def __init__(self, with_tracker=False):
        super(Timer, self).__init__()
        self.with_tracker = with_tracker
        self.preprocess_time_s = Times()
        self.inference_time_s = Times()
        self.postprocess_time_s = Times()
        self.tracking_time_s = Times()
        self.img_num = 0

    def info(self, average=False):
        pre_time = self.preprocess_time_s.value()
        infer_time = self.inference_time_s.value()
        post_time = self.postprocess_time_s.value()
        track_time = self.tracking_time_s.value()

        total_time = pre_time + infer_time + post_time
        if self.with_tracker:
            total_time = total_time + track_time
        total_time = round(total_time, 4)
        print("------------------ Inference Time Info ----------------------")
        print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
                                                       self.img_num))
        preprocess_time = round(pre_time / max(1, self.img_num),
                                4) if average else pre_time
        postprocess_time = round(post_time / max(1, self.img_num),
                                 4) if average else post_time
        inference_time = round(infer_time / max(1, self.img_num),
                               4) if average else infer_time
        tracking_time = round(track_time / max(1, self.img_num),
                              4) if average else track_time

        average_latency = total_time / max(1, self.img_num)
        qps = 0
        if total_time > 0:
            qps = 1 / average_latency
        print("average latency time(ms): {:.2f}, QPS: {:.2f}".format(
            average_latency * 1000, qps))
        if self.with_tracker:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000, tracking_time * 1000))
        else:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000))

    def report(self, average=False):
        dic = {}
        pre_time = self.preprocess_time_s.value()
        infer_time = self.inference_time_s.value()
        post_time = self.postprocess_time_s.value()
        track_time = self.tracking_time_s.value()

        dic['preprocess_time_s'] = round(pre_time / max(1, self.img_num),
                                         4) if average else pre_time
        dic['inference_time_s'] = round(infer_time / max(1, self.img_num),
                                        4) if average else infer_time
        dic['postprocess_time_s'] = round(post_time / max(1, self.img_num),
                                          4) if average else post_time
        dic['img_num'] = self.img_num
        total_time = pre_time + infer_time + post_time
        if self.with_tracker:
            dic['tracking_time_s'] = round(track_time / max(1, self.img_num),
                                           4) if average else track_time
            total_time = total_time + track_time
        dic['total_time_s'] = round(total_time, 4)
        return dic

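
# A hedged usage sketch (not part of the original module): shows the intended
# start()/end() bookkeeping around the three stages that Timer aggregates;
# the time.sleep() calls stand in for real preprocessing, inference and
# postprocessing work.
def _example_timer_usage() -> dict:
    timer = Timer()
    for _ in range(3):
        timer.preprocess_time_s.start()
        time.sleep(0.001)  # fake preprocessing
        timer.preprocess_time_s.end()

        timer.inference_time_s.start()
        time.sleep(0.002)  # fake inference
        timer.inference_time_s.end()

        timer.postprocess_time_s.start()
        time.sleep(0.001)  # fake postprocessing
        timer.postprocess_time_s.end()

        timer.img_num += 1
    timer.info(average=True)
    return timer.report(average=True)
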
def iou(box1, box2):
    """Compute the IoU (intersection over union) of two [x1, y1, x2, y2] boxes."""
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    inter_area = max(0, x2 - x1) * max(0, y2 - y1)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union_area = box1_area + box2_area - inter_area
    return inter_area / union_area if union_area != 0 else 0

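
# A hedged worked example with made-up coordinates: two 10x10 boxes that
# overlap in a 5x5 region give IoU = 25 / (100 + 100 - 25) ~= 0.143.
def _example_iou() -> float:
    value = iou([0, 0, 10, 10], [5, 5, 15, 15])
    assert abs(value - 25 / 175) < 1e-9
    return value
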
def non_max_suppression(boxes, scores, iou_threshold):
    """Non-maximum suppression; returns the indices of the boxes that are kept."""
    if not boxes:
        return []
    indices = np.argsort(scores)[::-1]
    selected_indices = []

    while len(indices) > 0:
        # Keep the highest-scoring remaining box.
        current = indices[0]
        selected_indices.append(current)
        remaining = indices[1:]

        # Drop every remaining box that overlaps the kept box too strongly.
        filtered_indices = []
        for i in remaining:
            if iou(boxes[current], boxes[i]) <= iou_threshold:
                filtered_indices.append(i)

        indices = np.array(filtered_indices)

    return selected_indices

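
# A hedged usage sketch with made-up boxes and scores: box 0 is suppressed by
# the higher-scoring, heavily overlapping box 1, so the indices [1, 2] are kept.
def _example_nms() -> List[int]:
    boxes = [
        [0, 0, 10, 10],
        [1, 1, 11, 11],    # overlaps box 0 with IoU ~= 0.68
        [50, 50, 60, 60],  # far away, always kept
    ]
    scores = [0.8, 0.9, 0.7]
    return non_max_suppression(boxes, scores, iou_threshold=0.5)
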
def is_box_inside(inner, outer):
    """Return True if the inner box lies entirely within the outer box."""
    return (outer[0] <= inner[0] and outer[1] <= inner[1] and
            outer[2] >= inner[2] and outer[3] >= inner[3])


def boxes_overlap(box1: List[float], box2: List[float]) -> bool:
    """Return True if the two boxes intersect."""
    x1_max = max(box1[0], box2[0])
    y1_max = max(box1[1], box2[1])
    x2_min = min(box1[2], box2[2])
    y2_min = min(box1[3], box2[3])
    return x1_max < x2_min and y1_max < y2_min


def merge_boxes(boxes: List[List[float]]) -> List[float]:
    """Return the smallest box that covers all of the given boxes."""
    x1 = min(box[0] for box in boxes)
    y1 = min(box[1] for box in boxes)
    x2 = max(box[2] for box in boxes)
    y2 = max(box[3] for box in boxes)
    return [x1, y1, x2, y2]

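
# A hedged worked example for the three helpers above, with made-up boxes:
# a fully contained box, two boxes sharing only a 10x10 corner, and the
# bounding box of the outer pair.
def _example_box_helpers() -> List[float]:
    outer = [0, 0, 100, 100]
    inner = [10, 10, 50, 50]
    shifted = [90, 90, 150, 150]
    assert is_box_inside(inner, outer)        # fully contained
    assert boxes_overlap(outer, shifted)      # share a 10x10 corner
    assert not boxes_overlap(inner, shifted)  # disjoint
    return merge_boxes([outer, shifted])      # -> [0, 0, 150, 150]
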
def merge_text_and_title_boxes(data: List[LayoutBox],
                               merged_labels: Tuple[int, ...],
                               other_labels: Tuple[int, ...],
                               merged_box_label: int) -> List[LayoutBox]:
    """Merge vertically adjacent boxes whose class is in `merged_labels` into a
    single box labelled `merged_box_label`, unless a box whose class is in
    `other_labels` intrudes into the merged region."""
    text_title_boxes = [(i, box) for i, box in enumerate(data)
                        if box.clsid in merged_labels]
    other_boxes = [box.pos for box in data if box.clsid in other_labels]

    text_title_boxes.sort(key=lambda x: x[1].pos[1])  # sort by y1

    merged = []
    skip_indices = set()  # indices into text_title_boxes already merged
    i = 0
    while i < len(text_title_boxes):
        if i in skip_indices:
            i += 1
            continue
        current_group = [text_title_boxes[i][1].pos]
        group_confidences = [text_title_boxes[i][1].confidence]
        j = i + 1
        while j < len(text_title_boxes):
            candidate_box = text_title_boxes[j][1].pos
            tentative_merge = merge_boxes(current_group + [candidate_box])
            # Stop growing the group as soon as the merged region would
            # swallow a box of another class.
            has_intruder = any(
                boxes_overlap(other, tentative_merge) for other in other_boxes)
            if has_intruder:
                break
            else:
                current_group.append(candidate_box)
                group_confidences.append(text_title_boxes[j][1].confidence)
                skip_indices.add(j)
            j += 1
        if len(current_group) > 1:
            merged_box = LayoutBox(merged_box_label,
                                   merge_boxes(current_group),
                                   max(group_confidences))
            merged.append(merged_box)
        else:
            idx = text_title_boxes[i][0]
            merged.append(data[idx])
        i += 1

    # Boxes of all other classes are passed through unchanged.
    remaining = [box for box in data if box.clsid not in merged_labels]
    return merged + remaining
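
# A hedged usage sketch; the class ids here (1 = text, 2 = title, 3 = some
# other element) are made-up assumptions for illustration, not the real
# layout-model label map. The title and the text box right below it are merged
# into one box labelled 1; the last text box stays separate because the
# class-3 box sits between them.
def _example_merge_text_and_title() -> List[LayoutBox]:
    layout = [
        LayoutBox(2, [50, 40, 400, 70], 0.95),    # title
        LayoutBox(1, [50, 80, 400, 200], 0.90),   # text directly below the title
        LayoutBox(3, [50, 220, 400, 400], 0.85),  # blocking element (e.g. a table)
        LayoutBox(1, [50, 420, 400, 500], 0.88),  # text below the blocking element
    ]
    return merge_text_and_title_boxes(
        layout, merged_labels=(1, 2), other_labels=(3,), merged_box_label=1)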