# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layout-detection result containers, timing helpers and box post-processing."""

import time
import os
import ast
import argparse
from typing import List, Tuple

import numpy as np


class LayoutBox(object):
    """A single detected layout region.

    Attributes:
        clsid: integer class id of the region.
        pos: box coordinates; the helpers in this module index it as
            [x1, y1, x2, y2].
        confidence: detection score.
    """

    def __init__(self, clsid: int, pos: List[float], confidence: float):
        self.clsid = clsid
        self.pos = pos
        self.confidence = confidence


class PageDetectionResult(object):
    """Detected layout boxes for one page together with the page image."""

    def __init__(self, boxes: List[LayoutBox], image: np.ndarray):
        self.boxes = boxes
        self.image = image


class Times(object):
    """Stopwatch based on ``time.time()`` that accumulates elapsed seconds."""

    def __init__(self):
        # Elapsed time in seconds (accumulated or last measurement).
        self.time = 0.
        # start time
        self.st = 0.
        # end time
        self.et = 0.

    def start(self):
        """Record the start timestamp."""
        self.st = time.time()

    def end(self, repeats=1, accumulative=True):
        """Record the end timestamp and update ``self.time``.

        Args:
            repeats: the measured interval is divided by this count, so a
                loop of ``repeats`` iterations yields a per-iteration time.
            accumulative: if True, add the interval to ``self.time``;
                otherwise overwrite it.
        """
        self.et = time.time()
        if accumulative:
            self.time += (self.et - self.st) / repeats
        else:
            self.time = (self.et - self.st) / repeats

    def reset(self):
        """Zero the elapsed time and both timestamps."""
        self.time = 0.
        self.st = 0.
        self.et = 0.

    def value(self):
        """Return the elapsed time in seconds, rounded to 4 decimals."""
        return round(self.time, 4)


class Timer(Times):
    """Collects preprocess / inference / postprocess (/ tracking) timings.

    Args:
        with_tracker: when True, tracking time is included in totals and
            reported separately.
    """

    def __init__(self, with_tracker=False):
        super(Timer, self).__init__()
        self.with_tracker = with_tracker
        self.preprocess_time_s = Times()
        self.inference_time_s = Times()
        self.postprocess_time_s = Times()
        self.tracking_time_s = Times()
        self.img_num = 0

    def _per_image(self, seconds, average):
        """Return ``seconds`` divided by ``img_num`` (rounded to 4 places) if
        ``average`` is set, else ``seconds`` unchanged."""
        if not average:
            return seconds
        # max(1, ...) avoids division by zero before any image is counted.
        return round(seconds / max(1, self.img_num), 4)

    def info(self, average=False):
        """Print a human-readable timing summary (times shown in ms).

        Args:
            average: if True, per-stage times are averaged over ``img_num``.
        """
        pre_time = self.preprocess_time_s.value()
        infer_time = self.inference_time_s.value()
        post_time = self.postprocess_time_s.value()
        track_time = self.tracking_time_s.value()

        total_time = pre_time + infer_time + post_time
        if self.with_tracker:
            total_time = total_time + track_time
        total_time = round(total_time, 4)
        print("------------------ Inference Time Info ----------------------")
        print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
                                                       self.img_num))
        preprocess_time = self._per_image(pre_time, average)
        postprocess_time = self._per_image(post_time, average)
        inference_time = self._per_image(infer_time, average)
        tracking_time = self._per_image(track_time, average)

        average_latency = total_time / max(1, self.img_num)
        qps = 0
        if total_time > 0:
            qps = 1 / average_latency
        # "{:.2f}" (two decimals); the former "{:2f}" meant width 2 with six
        # decimals.
        print("average latency time(ms): {:.2f}, QPS: {:.2f}".format(
            average_latency * 1000, qps))
        if self.with_tracker:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000, tracking_time * 1000))
        else:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000))

    def report(self, average=False):
        """Return the timing summary as a dict (times in seconds).

        Args:
            average: if True, per-stage times are averaged over ``img_num``;
                ``total_time_s`` is always the rounded total.

        Returns:
            Dict with keys ``preprocess_time_s``, ``inference_time_s``,
            ``postprocess_time_s``, ``img_num``, ``total_time_s`` and, when
            ``with_tracker`` is set, ``tracking_time_s``.
        """
        dic = {}
        pre_time = self.preprocess_time_s.value()
        infer_time = self.inference_time_s.value()
        post_time = self.postprocess_time_s.value()
        track_time = self.tracking_time_s.value()

        dic['preprocess_time_s'] = self._per_image(pre_time, average)
        dic['inference_time_s'] = self._per_image(infer_time, average)
        dic['postprocess_time_s'] = self._per_image(post_time, average)
        dic['img_num'] = self.img_num
        total_time = pre_time + infer_time + post_time
        if self.with_tracker:
            dic['tracking_time_s'] = self._per_image(track_time, average)
            total_time = total_time + track_time
        dic['total_time_s'] = round(total_time, 4)
        return dic


def iou(box1, box2):
    """Return the intersection-over-union of two [x1, y1, x2, y2] boxes.

    Returns 0 when the union area is zero.
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    inter_area = max(0, x2 - x1) * max(0, y2 - y1)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area

    return inter_area / union_area if union_area != 0 else 0


def non_max_suppression(boxes, scores, iou_threshold):
    """Greedy non-maximum suppression.

    Args:
        boxes: sequence (list or numpy array) of [x1, y1, x2, y2] boxes.
        scores: per-box scores, same length as ``boxes``.
        iou_threshold: a box is dropped when its IoU with an already kept,
            higher-scoring box exceeds this value.

    Returns:
        List of indices into ``boxes`` of the kept boxes, highest score first.
    """
    # len() rather than truthiness: ``not boxes`` raises ValueError for a
    # numpy array with more than one element.
    if len(boxes) == 0:
        return []

    order = list(np.argsort(scores)[::-1])
    selected_boxes = []
    while order:
        current = order[0]
        selected_boxes.append(current)
        order = [i for i in order[1:]
                 if iou(boxes[current], boxes[i]) <= iou_threshold]
    return selected_boxes


def is_box_inside(inner, outer):
    """Return True if ``inner`` lies entirely within ``outer`` (edges inclusive)."""
    return (outer[0] <= inner[0] and outer[1] <= inner[1] and
            outer[2] >= inner[2] and outer[3] >= inner[3])


def boxes_overlap(box1: List[float], box2: List[float]) -> bool:
    """Return True if the two boxes share a region of positive area.

    Boxes that only touch along an edge do not count as overlapping.
    """
    x1_max = max(box1[0], box2[0])
    y1_max = max(box1[1], box2[1])
    x2_min = min(box1[2], box2[2])
    y2_min = min(box1[3], box2[3])
    return x1_max < x2_min and y1_max < y2_min


def merge_boxes(boxes: List[List[float]]) -> List[float]:
    """Return the smallest [x1, y1, x2, y2] box enclosing all ``boxes``."""
    x1 = min(box[0] for box in boxes)
    y1 = min(box[1] for box in boxes)
    x2 = max(box[2] for box in boxes)
    y2 = max(box[3] for box in boxes)
    return [x1, y1, x2, y2]


def merge_text_and_title_boxes(data: List[LayoutBox],
                               merged_labels: Tuple[int, ...],
                               other_labels: Tuple[int, ...],
                               merged_box_label: int) -> List[LayoutBox]:
    """Greedily merge vertically consecutive boxes of the mergeable classes.

    Boxes whose ``clsid`` is in ``merged_labels`` are visited top to bottom
    (by y1). Each one starts a group that absorbs the following boxes as long
    as the enclosing box of the group does not overlap any box whose
    ``clsid`` is in ``other_labels``.

    Args:
        data: all detected boxes.
        merged_labels: class ids eligible for merging.
        other_labels: class ids that block a merge when overlapped.
        merged_box_label: class id given to every box produced by a merge.

    Returns:
        The merged / unmerged boxes of the mergeable classes (in y1 order),
        followed by every box of ``data`` whose class is not mergeable, in
        its original order. A group of one keeps its original LayoutBox; a
        larger group becomes a new LayoutBox with the maximum confidence of
        its members.
    """
    # (index in data, box) pairs for mergeable boxes; list.sort is stable.
    text_title_boxes = [(i, box) for i, box in enumerate(data)
                        if box.clsid in merged_labels]
    other_boxes = [box.pos for box in data if box.clsid in other_labels]
    text_title_boxes.sort(key=lambda x: x[1].pos[1])  # sort by y1

    merged = []
    count = len(text_title_boxes)
    i = 0
    while i < count:
        current_group = [text_title_boxes[i][1].pos]
        group_confidences = [text_title_boxes[i][1].confidence]

        j = i + 1
        while j < count:
            candidate_box = text_title_boxes[j][1].pos
            tentative_merge = merge_boxes(current_group + [candidate_box])
            if any(boxes_overlap(other, tentative_merge)
                   for other in other_boxes):
                break
            current_group.append(candidate_box)
            group_confidences.append(text_title_boxes[j][1].confidence)
            j += 1

        if len(current_group) > 1:
            merged.append(LayoutBox(merged_box_label,
                                    merge_boxes(current_group),
                                    max(group_confidences)))
        else:
            merged.append(data[text_title_boxes[i][0]])
        # Candidates i+1 .. j-1 were absorbed into this group; resume at the
        # first one that was not.
        i = j

    # Keep every non-mergeable box. The previous version also filtered by
    # positions in the *sorted* candidate list, which wrongly dropped
    # unrelated boxes whose index in ``data`` happened to match.
    remaining = [box for box in data if box.clsid not in merged_labels]
    return merged + remaining