# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import os
import ast
import argparse
from typing import List, Tuple
import numpy as np


class LayoutBox:
    """A single detected layout region.

    Attributes:
        clsid: Integer class label of the region.
        pos: Box coordinates; the helpers in this module index it as
            ``[x1, y1, x2, y2]``.
        confidence: Detection score of the region.
    """

    def __init__(self, clsid: int, pos: List[float], confidence: float):
        # The position list is stored by reference, not copied.
        self.clsid, self.pos, self.confidence = clsid, pos, confidence


class PageDetectionResult:
    """Detection output for one page: the boxes found and the page image.

    Attributes:
        boxes: Layout boxes detected on the page.
        image: The image array associated with these detections.
    """

    def __init__(self, boxes: List[LayoutBox], image: np.ndarray):
        # Both arguments are stored by reference.
        self.boxes, self.image = boxes, image


class Times:
    """Wall-clock stopwatch that can accumulate time over several runs.

    Attributes:
        time: Recorded duration in seconds.
        st: Timestamp taken by the last ``start()`` call.
        et: Timestamp taken by the last ``end()`` call.
    """

    def __init__(self):
        self.time = self.st = self.et = 0.

    def start(self):
        """Record the current time as the start of a measurement."""
        self.st = time.time()

    def end(self, repeats=1, accumulative=True):
        """Close the current measurement and record its duration.

        Args:
            repeats: Number of repetitions the measured span covered; the
                elapsed time is divided by it.
            accumulative: When true the duration is added to ``time``,
                otherwise it replaces ``time``.
        """
        self.et = time.time()
        elapsed = (self.et - self.st) / repeats
        self.time = self.time + elapsed if accumulative else elapsed

    def reset(self):
        """Clear the recorded duration and both timestamps."""
        self.time = self.st = self.et = 0.

    def value(self):
        """Return the recorded duration rounded to 4 decimal places."""
        return round(self.time, 4)


class Timer(Times):
    """Collect and report timings for the stages of an inference pipeline.

    Holds one ``Times`` accumulator per stage (pre-process, inference,
    post-process and tracking) together with an image counter. This class
    only reads those values: callers drive the stage timers and set
    ``img_num`` themselves.

    Args:
        with_tracker: When true, the tracking stage is included in the total
            time and in the printed / reported stage breakdown.
    """

    def __init__(self, with_tracker=False):
        super(Timer, self).__init__()
        self.with_tracker = with_tracker
        self.preprocess_time_s = Times()
        self.inference_time_s = Times()
        self.postprocess_time_s = Times()
        self.tracking_time_s = Times()
        self.img_num = 0

    def _stage_time(self, stage_time, average):
        """Return ``stage_time`` as is, or its per-image average.

        The image count is clamped to at least 1 so that averaging before
        any image was counted does not divide by zero.
        """
        if average:
            return round(stage_time / max(1, self.img_num), 4)
        return stage_time

    def info(self, average=False):
        """Print a summary of the recorded timings to stdout.

        Args:
            average: When true, the per-stage figures are divided by the
                number of images; the total is always printed un-averaged.
        """
        pre_time = self.preprocess_time_s.value()
        infer_time = self.inference_time_s.value()
        post_time = self.postprocess_time_s.value()
        track_time = self.tracking_time_s.value()

        total_time = pre_time + infer_time + post_time
        if self.with_tracker:
            total_time = total_time + track_time
        total_time = round(total_time, 4)
        print("------------------ Inference Time Info ----------------------")
        print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
                                                       self.img_num))
        preprocess_time = self._stage_time(pre_time, average)
        postprocess_time = self._stage_time(post_time, average)
        inference_time = self._stage_time(infer_time, average)
        tracking_time = self._stage_time(track_time, average)

        average_latency = total_time / max(1, self.img_num)
        # With no recorded time the latency is 0, so QPS is reported as 0
        # instead of dividing by zero.
        qps = 1 / average_latency if total_time > 0 else 0
        # QPS uses two decimals like every other figure; the previous
        # "{:2f}" spec meant "minimum width 2" and printed six decimals.
        print("average latency time(ms): {:.2f}, QPS: {:.2f}".format(
            average_latency * 1000, qps))
        if self.with_tracker:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000, tracking_time * 1000))
        else:
            print(
                "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
                format(preprocess_time * 1000, inference_time * 1000,
                       postprocess_time * 1000))

    def report(self, average=False):
        """Return the recorded timings as a dictionary.

        Args:
            average: When true, the per-stage entries are divided by the
                number of images; ``total_time_s`` is never averaged.

        Returns:
            A dict with ``preprocess_time_s``, ``inference_time_s``,
            ``postprocess_time_s``, ``img_num`` and ``total_time_s``, plus
            ``tracking_time_s`` when the timer was created with a tracker.
        """
        pre_time = self.preprocess_time_s.value()
        infer_time = self.inference_time_s.value()
        post_time = self.postprocess_time_s.value()
        track_time = self.tracking_time_s.value()

        dic = {
            'preprocess_time_s': self._stage_time(pre_time, average),
            'inference_time_s': self._stage_time(infer_time, average),
            'postprocess_time_s': self._stage_time(post_time, average),
            'img_num': self.img_num,
        }
        total_time = pre_time + infer_time + post_time
        if self.with_tracker:
            dic['tracking_time_s'] = self._stage_time(track_time, average)
            total_time = total_time + track_time
        dic['total_time_s'] = round(total_time, 4)
        return dic


def iou(box1, box2):
    """Compute the IoU (intersection over union) of two boxes.

    Both boxes are indexed as ``[x1, y1, x2, y2]``. Returns 0 when the
    union area is zero.
    """
    # Corners of the intersection rectangle.
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    # A negative extent means the boxes do not overlap on that axis.
    overlap = max(0, right - left) * max(0, bottom - top)
    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union = area_a + area_b - overlap
    if union == 0:
        return 0
    return overlap / union

def non_max_suppression(boxes, scores, iou_threshold):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every
    other remaining box whose IoU with it is greater than ``iou_threshold``.

    Args:
        boxes: Sequence (list or NumPy array) of boxes, each indexed as
            ``[x1, y1, x2, y2]``.
        scores: Sequence of scores, one per box.
        iou_threshold: Boxes with an IoU above this value relative to a kept
            box are suppressed.

    Returns:
        A list of indices into ``boxes`` of the kept boxes, ordered by
        descending score. Empty when ``boxes`` is ``None`` or empty.
    """
    # len() works for both lists and NumPy arrays, whereas `not boxes`
    # raises ValueError for an array holding more than one element.
    if boxes is None or len(boxes) == 0:
        return []

    order = np.argsort(scores)[::-1]
    keep = []

    while len(order) > 0:
        best = order[0]
        keep.append(best)
        survivors = [
            idx for idx in order[1:]
            if iou(boxes[best], boxes[idx]) <= iou_threshold
        ]
        # Keep the integer dtype even when no candidate survives.
        order = np.array(survivors, dtype=order.dtype)

    return keep


def is_box_inside(inner, outer):
    """Tell whether ``inner`` lies completely within ``outer``.

    Both boxes are indexed as ``[x1, y1, x2, y2]``; shared edges count as
    inside.
    """
    # The top-left corner of inner must not lie before that of outer ...
    top_left_ok = outer[0] <= inner[0] and outer[1] <= inner[1]
    # ... and its bottom-right corner must not lie past that of outer.
    return top_left_ok and (outer[2] >= inner[2] and outer[3] >= inner[3])


def boxes_overlap(box1: List[float], box2: List[float]) -> bool:
    """Tell whether two ``[x1, y1, x2, y2]`` boxes intersect.

    Boxes that merely touch along an edge or at a corner do not count as
    overlapping, because the comparisons are strict.
    """
    # On each axis the larger start must lie strictly before the smaller end.
    horizontal = max(box1[0], box2[0]) < min(box1[2], box2[2])
    vertical = max(box1[1], box2[1]) < min(box1[3], box2[3])
    return horizontal and vertical


def merge_boxes(boxes: List[List[float]]) -> List[float]:
    """Return the smallest ``[x1, y1, x2, y2]`` box enclosing all ``boxes``."""
    lefts = [b[0] for b in boxes]
    tops = [b[1] for b in boxes]
    rights = [b[2] for b in boxes]
    bottoms = [b[3] for b in boxes]
    # Outermost coordinate on every side.
    return [min(lefts), min(tops), max(rights), max(bottoms)]


def merge_text_and_title_boxes(
        data: List[LayoutBox],
        merged_labels: Tuple[int, ...],
        other_labels: Tuple[int, ...],
        merged_box_label: int) -> List[LayoutBox]:
    """Merge vertically consecutive boxes of selected classes into larger boxes.

    Boxes whose class is in ``merged_labels`` are ordered by their top edge
    (``pos[1]``) and greedily grouped: a box joins the current group as long
    as the bounding box of the enlarged group does not overlap any box whose
    class is in ``other_labels``. When a candidate would cause such an
    overlap, the current group is closed and a new one starts at that
    candidate.

    Args:
        data: All layout boxes detected on a page.
        merged_labels: Class ids that take part in merging.
        other_labels: Class ids whose boxes must not be covered by a merged
            box.
        merged_box_label: Class id assigned to every box produced by merging
            two or more boxes.

    Returns:
        The merged boxes (top to bottom) followed by every box of ``data``
        whose class is not in ``merged_labels``, in original order. A group
        of one box is returned as the original object. A merged group becomes
        a new ``LayoutBox`` with the highest confidence of its members.
    """
    # sorted() is stable, so boxes with equal y1 keep their order from data.
    candidates = sorted(
        (box for box in data if box.clsid in merged_labels),
        key=lambda box: box.pos[1])
    obstacles = [box.pos for box in data if box.clsid in other_labels]

    merged = []
    start = 0
    while start < len(candidates):
        first = candidates[start]
        group_pos = first.pos
        group_confidence = first.confidence
        group_size = 1

        nxt = start + 1
        while nxt < len(candidates):
            candidate = candidates[nxt]
            tentative = merge_boxes([group_pos, candidate.pos])
            if any(boxes_overlap(other, tentative) for other in obstacles):
                # Growing further would cover a protected box: close the group.
                break
            group_pos = tentative
            group_confidence = max(group_confidence, candidate.confidence)
            group_size += 1
            nxt += 1

        if group_size > 1:
            merged.append(
                LayoutBox(merged_box_label, group_pos, group_confidence))
        else:
            merged.append(first)
        # Groups are contiguous runs, so the next one starts where this ended.
        start = nxt

    # Every box outside merged_labels is kept. The candidate positions used
    # during grouping index the sorted list, not data, and must not be used
    # to filter data, otherwise unrelated boxes get dropped.
    remaining = [box for box in data if box.clsid not in merged_labels]
    return merged + remaining