You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

231 lines
8.3 KiB
Python

1 month ago
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import ast
import argparse
from typing import List, Tuple
import numpy as np
class LayoutBox:
    """A single detected layout region.

    Attributes:
        clsid: integer class id of the region.
        pos: box coordinates (the helpers in this module index it as
            ``[x1, y1, x2, y2]``).
        confidence: detection score of the region.
    """

    def __init__(self, clsid: int, pos: List[float], confidence: float):
        self.clsid, self.pos, self.confidence = clsid, pos, confidence
class PageDetectionResult:
    """Layout detection output for one page image.

    Attributes:
        boxes: the detected ``LayoutBox`` regions.
        image: the page image the boxes belong to.
    """

    def __init__(self, boxes: List[LayoutBox], image: np.ndarray):
        self.boxes, self.image = boxes, image
class Times(object):
    """Simple wall-clock stopwatch based on ``time.time()``.

    Attributes:
        time: measured duration in seconds (accumulated across ``end`` calls
            unless ``accumulative=False`` is passed).
        st: timestamp recorded by the most recent ``start`` call.
        et: timestamp recorded by the most recent ``end`` call.
    """

    def __init__(self):
        self.time = self.st = self.et = 0.

    def start(self):
        """Record the start timestamp."""
        self.st = time.time()

    def end(self, repeats=1, accumulative=True):
        """Record the end timestamp and update ``time``.

        Args:
            repeats: the elapsed interval is divided by this value, so timing
                ``repeats`` iterations at once yields a per-iteration time.
            accumulative: add the interval to ``time`` when True; overwrite
                ``time`` with it when False.
        """
        self.et = time.time()
        elapsed = (self.et - self.st) / repeats
        if not accumulative:
            self.time = elapsed
            return
        self.time += elapsed

    def reset(self):
        """Zero the measured duration and both timestamps."""
        self.time = self.st = self.et = 0.

    def value(self):
        """Return the measured duration rounded to 4 decimal places."""
        return round(self.time, 4)
class Timer(Times):
    """Per-stage timing statistics for an inference pipeline.

    Holds one ``Times`` stopwatch per stage (preprocess, inference,
    postprocess, tracking).  ``img_num`` is a counter the caller is expected
    to update so per-image averages can be reported.  Stage times are in
    seconds; ``info`` prints them converted to milliseconds.

    Args:
        with_tracker: when True, tracking time is included in the total and
            in the per-stage breakdown.
    """

    def __init__(self, with_tracker=False):
        super(Timer, self).__init__()
        self.with_tracker = with_tracker
        self.preprocess_time_s = Times()
        self.inference_time_s = Times()
        self.postprocess_time_s = Times()
        self.tracking_time_s = Times()
        self.img_num = 0

    def _stage_values(self):
        """Return the rounded (pre, infer, post, track) stage times."""
        return (self.preprocess_time_s.value(),
                self.inference_time_s.value(),
                self.postprocess_time_s.value(),
                self.tracking_time_s.value())

    def _total_time(self, pre_time, infer_time, post_time, track_time):
        """Sum the stage times (tracking only if enabled), rounded to 4 places."""
        total_time = pre_time + infer_time + post_time
        if self.with_tracker:
            total_time = total_time + track_time
        return round(total_time, 4)

    def _per_image(self, value, average):
        """Return ``value`` divided by the image count when ``average`` is set.

        ``max(1, img_num)`` avoids division by zero before any image has
        been counted.
        """
        if not average:
            return value
        return round(value / max(1, self.img_num), 4)

    def info(self, average=False):
        """Print a human-readable timing summary (times in milliseconds).

        Args:
            average: print per-image stage times (divided by ``img_num``)
                instead of the accumulated stage totals.
        """
        pre_time, infer_time, post_time, track_time = self._stage_values()
        total_time = self._total_time(pre_time, infer_time, post_time,
                                      track_time)
        print("------------------ Inference Time Info ----------------------")
        print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
                                                       self.img_num))
        preprocess_time = self._per_image(pre_time, average)
        inference_time = self._per_image(infer_time, average)
        postprocess_time = self._per_image(post_time, average)
        tracking_time = self._per_image(track_time, average)

        average_latency = total_time / max(1, self.img_num)
        qps = 0
        if total_time > 0:
            qps = 1 / average_latency
        print("average latency time(ms): {:.2f}, QPS: {:.2f}".format(
            average_latency * 1000, qps))
        if self.with_tracker:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000, tracking_time * 1000))
        else:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000))

    def report(self, average=False):
        """Return the timing statistics as a dict (times in seconds).

        Keys, in order: ``preprocess_time_s``, ``inference_time_s``,
        ``postprocess_time_s``, ``img_num``, ``tracking_time_s`` (only when
        ``with_tracker``) and ``total_time_s``.  With ``average=True`` the
        stage times are divided by ``img_num``; ``total_time_s`` is always
        the accumulated total.
        """
        pre_time, infer_time, post_time, track_time = self._stage_values()
        dic = {
            'preprocess_time_s': self._per_image(pre_time, average),
            'inference_time_s': self._per_image(infer_time, average),
            'postprocess_time_s': self._per_image(post_time, average),
            'img_num': self.img_num,
        }
        if self.with_tracker:
            dic['tracking_time_s'] = self._per_image(track_time, average)
        dic['total_time_s'] = self._total_time(pre_time, infer_time,
                                               post_time, track_time)
        return dic
def iou(box1, box2):
    """Return the Intersection-over-Union of two ``[x1, y1, x2, y2]`` boxes.

    Returns 0 when the union area is zero (e.g. two degenerate boxes).
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])
    overlap = max(0, right - left) * max(0, bottom - top)

    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_a + area_b - overlap
    if union == 0:
        return 0
    return overlap / union
def non_max_suppression(boxes, scores, iou_threshold):
    """Greedy non-maximum suppression.

    Boxes are visited in descending score order; each kept box discards every
    remaining box whose IoU with it is strictly greater than ``iou_threshold``.

    Args:
        boxes: sequence (list or NumPy array) of ``[x1, y1, x2, y2]`` boxes,
            or None.
        scores: one score per box.
        iou_threshold: IoU above which a lower-scored box is suppressed.

    Returns:
        List of indices into ``boxes`` of the kept boxes, highest score first.
    """
    # ``not boxes`` raises ValueError for a NumPy array with more than one
    # element, so check the length explicitly; None still means "no boxes".
    if boxes is None or len(boxes) == 0:
        return []
    order = list(np.argsort(scores)[::-1])
    keep = []
    while order:
        current = order[0]
        keep.append(int(current))
        order = [idx for idx in order[1:]
                 if iou(boxes[current], boxes[idx]) <= iou_threshold]
    return keep
def is_box_inside(inner, outer):
    """Return whether ``inner`` lies entirely within ``outer``.

    Both boxes are ``[x1, y1, x2, y2]``; shared edges count as inside.
    """
    return (inner[0] >= outer[0]
            and inner[1] >= outer[1]
            and inner[2] <= outer[2]
            and inner[3] <= outer[3])
def boxes_overlap(box1: List[float], box2: List[float]) -> bool:
    """Return whether two ``[x1, y1, x2, y2]`` boxes share a region.

    The comparison is strict, so boxes that only touch along an edge or at a
    corner are not considered overlapping.
    """
    horizontal = max(box1[0], box2[0]) < min(box1[2], box2[2])
    vertical = max(box1[1], box2[1]) < min(box1[3], box2[3])
    return horizontal and vertical
def merge_boxes(boxes: List[List[float]]) -> List[float]:
    """Return the smallest ``[x1, y1, x2, y2]`` box enclosing all ``boxes``.

    Raises:
        ValueError: if ``boxes`` is empty (propagated from ``min``).
    """
    return [
        min(b[0] for b in boxes),
        min(b[1] for b in boxes),
        max(b[2] for b in boxes),
        max(b[3] for b in boxes),
    ]
def merge_text_and_title_boxes(data: List[LayoutBox],
                               merged_labels: Tuple[int, ...],
                               other_labels: Tuple[int, ...],
                               merged_box_label: int) -> List[LayoutBox]:
    """Greedily merge consecutive eligible boxes unless another box is in the way.

    Boxes whose ``clsid`` is in ``merged_labels`` are sorted by their top edge
    (``pos[1]``, stable for ties).  Starting from the first unassigned box,
    each following box joins the current group as long as the bounding box
    of the group plus that box does not overlap any box whose ``clsid`` is in
    ``other_labels``; the first candidate that would cause an overlap starts
    the next group.

    Args:
        data: detected layout boxes.
        merged_labels: class ids eligible for merging.
        other_labels: class ids of boxes a merged region must not overlap.
        merged_box_label: class id given to boxes produced by merging.

    Returns:
        One entry per group, in top-edge order, followed by every box of
        ``data`` whose ``clsid`` is not in ``merged_labels`` (original order).
        A single-box group is returned as the original ``LayoutBox`` object;
        a larger group becomes a new ``LayoutBox`` with ``merged_box_label``,
        the group's bounding box and the group's highest confidence.
    """
    candidates = sorted((box for box in data if box.clsid in merged_labels),
                        key=lambda box: box.pos[1])  # sort by y1
    obstacles = [box.pos for box in data if box.clsid in other_labels]

    merged = []
    start = 0
    while start < len(candidates):
        first = candidates[start]
        group = [first]
        # Bounding box of the group so far, extended one candidate at a time
        # (min/max are exact, so this equals merging the whole group).
        group_bbox = first.pos
        end = start + 1
        while end < len(candidates):
            tentative = merge_boxes([group_bbox, candidates[end].pos])
            if any(boxes_overlap(other, tentative) for other in obstacles):
                break
            group.append(candidates[end])
            group_bbox = tentative
            end += 1
        if len(group) > 1:
            merged.append(
                LayoutBox(merged_box_label, group_bbox,
                          max(box.confidence for box in group)))
        else:
            merged.append(first)
        start = end

    # Every box with a merged label was emitted above (alone or inside a
    # merged box), so the rest is exactly the boxes with any other label.
    # Filtering by position here would mix sorted-candidate positions with
    # indices into ``data`` and drop unrelated boxes.
    remaining = [box for box in data if box.clsid not in merged_labels]
    return merged + remaining