# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ppocr.metrics.det_metric import DetMetric


class TableStructureMetric(object):
    """Exact-match accuracy of predicted table-structure token sequences.

    A sample counts as correct only when its concatenated predicted tokens equal
    its concatenated ground-truth tokens. Counts accumulate over successive
    calls until ``get_metric`` (or ``reset``) clears them.
    """

    def __init__(self, main_indicator="acc", eps=1e-6, del_thead_tbody=False, **kwargs):
        """
        @param main_indicator: name of the metric key used for model selection
        @param eps: constant added to the denominator in ``get_metric``
        @param del_thead_tbody: if True, "<thead>", "</thead>", "<tbody>" and
            "</tbody>" are removed from both sequences before comparison
        @param kwargs: accepted and ignored
        """
        self.main_indicator = main_indicator
        self.eps = eps
        self.del_thead_tbody = del_thead_tbody
        self.reset()

    def __call__(self, pred_label, batch=None, *args, **kwargs):
        """Accumulate exact-match counts for one batch.

        @param pred_label: ``(preds, labels)``. Each item of
            ``preds["structure_batch_list"]`` is a ``(tokens, score)`` pair (the
            score is unused here); each item of ``labels["structure_batch_list"]``
            is a token sequence.
        """
        preds, labels = pred_label
        section_tags = ("<thead>", "</thead>", "<tbody>", "</tbody>")

        def normalize(tokens):
            # Join the tokens and optionally strip table-section wrapper tags.
            text = "".join(tokens)
            if self.del_thead_tbody:
                for tag in section_tags:
                    text = text.replace(tag, "")
            return text

        matched = 0
        seen = 0
        for (pred_tokens, _pred_score), gt_tokens in zip(
            preds["structure_batch_list"], labels["structure_batch_list"]
        ):
            seen += 1
            if normalize(pred_tokens) == normalize(gt_tokens):
                matched += 1

        self.correct_num += matched
        self.all_num += seen

    def get_metric(self):
        """Return ``{"acc": correct / (total + eps)}`` and reset the counters."""
        denominator = self.all_num + self.eps
        result = {"acc": 1.0 * self.correct_num / denominator}
        self.reset()
        return result

    def reset(self):
        """Zero all running counters."""
        self.all_num = 0
        self.correct_num = 0
        # The fields below are not read by any method of this class; presumably
        # kept so external code that inspects them still finds them.
        self.len_acc_num = 0
        self.token_nums = 0
        self.anys_dict = {}


class TableMetric(object):
    """Table-recognition metric: structure accuracy plus an optional box metric.

    Structure accuracy always comes from ``TableStructureMetric``. When
    ``compute_bbox_metric`` is True, predicted and ground-truth cell boxes are
    converted to 4-point polygons by ``format_box`` and scored with
    ``DetMetric``; ``get_metric`` then merges the two result dicts.
    """

    def __init__(
        self,
        main_indicator="acc",
        compute_bbox_metric=False,
        box_format="xyxy",
        del_thead_tbody=False,
        **kwargs,
    ):
        """
        @param main_indicator: metric key used to select the best model. When it
            equals the bbox metric's ``main_indicator``, the bbox results are the
            top-level keys of ``get_metric``'s output; otherwise the structure
            results are.
        @param compute_bbox_metric: if True, also evaluate cell boxes with DetMetric.
        @param box_format: layout of each raw box: "xyxy", "xywh" (center x,
            center y, width, height) or "xyxyxyxy". Any other value leaves boxes
            unchanged.
        @param del_thead_tbody: forwarded to TableStructureMetric.
        @param kwargs: accepted and ignored.
        """
        self.structure_metric = TableStructureMetric(del_thead_tbody=del_thead_tbody)
        self.bbox_metric = DetMetric() if compute_bbox_metric else None
        self.main_indicator = main_indicator
        self.box_format = box_format
        self.reset()

    def __call__(self, pred_label, batch=None, *args, **kwargs):
        """Accumulate one batch into the structure metric and, if enabled, the bbox metric.

        @param pred_label: ``(preds, labels)`` dicts holding "structure_batch_list"
            and, when the bbox metric is enabled, "bbox_batch_list".
        """
        self.structure_metric(pred_label)
        if self.bbox_metric is not None:
            self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))

    def prepare_bbox_metric_input(self, pred_label):
        """Convert table cell boxes into the argument pair passed to DetMetric.

        @param pred_label: ``(preds, labels)``; ``preds["bbox_batch_list"]`` and
            ``labels["bbox_batch_list"]`` hold one list of raw boxes per sample.
        @return: ``[pred_list, [0, 0, gt_list, ignore_list]]`` where ``pred_list``
            has one ``{"points": [...]}`` dict per sample, ``gt_list`` one list of
            polygons per sample, and ``ignore_list`` a 0 tag per ground-truth box.
        """
        preds, labels = pred_label

        pred_bbox_batch_list = []
        gt_bbox_batch_list = []
        gt_ignore_tags_batch_list = []
        for batch_idx, pred_boxes in enumerate(preds["bbox_batch_list"]):
            pred_bbox_batch_list.append(
                {"points": [self.format_box(box) for box in pred_boxes]}
            )
            # Index labels (rather than zip) so a labels list shorter than preds
            # raises instead of silently truncating the batch.
            gt_boxes = [
                self.format_box(box) for box in labels["bbox_batch_list"][batch_idx]
            ]
            gt_bbox_batch_list.append(gt_boxes)
            # Every ground-truth cell is scored; none is tagged as ignored.
            gt_ignore_tags_batch_list.append([0] * len(gt_boxes))

        # NOTE(review): the two leading zeros fill batch slots that DetMetric
        # presumably does not read on this path — confirm against DetMetric.
        return [
            pred_bbox_batch_list,
            [0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list],
        ]

    def get_metric(self):
        """Return the merged metric dict.

        Without a bbox metric this is the structure metric's dict. Otherwise, if
        ``self.main_indicator`` equals the bbox metric's ``main_indicator``, the
        bbox results are top-level and the structure results are added with a
        "structure_metric_" prefix; in every other case the structure results are
        top-level and the bbox results get a "bbox_metric_" prefix.
        """
        structure_metric = self.structure_metric.get_metric()
        if self.bbox_metric is None:
            return structure_metric
        bbox_metric = self.bbox_metric.get_metric()
        if self.main_indicator == self.bbox_metric.main_indicator:
            output = bbox_metric
            for sub_key, value in structure_metric.items():
                output["structure_metric_{}".format(sub_key)] = value
        else:
            output = structure_metric
            for sub_key, value in bbox_metric.items():
                output["bbox_metric_{}".format(sub_key)] = value
        return output

    def reset(self):
        """Reset the structure metric and, if present, the bbox metric."""
        self.structure_metric.reset()
        if self.bbox_metric is not None:
            self.bbox_metric.reset()

    def format_box(self, box):
        """Convert a raw box in ``self.box_format`` to a list of four [x, y] points.

        For "xyxy" and "xywh" the corner order is (x1, y1), (x2, y1), (x2, y2),
        (x1, y2); "xyxyxyxy" keeps the input point order. An unrecognised format
        returns ``box`` unchanged.
        """
        if self.box_format == "xyxy":
            x1, y1, x2, y2 = box
            box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
        elif self.box_format == "xywh":
            x, y, w, h = box
            # True division so odd or fractional sizes (e.g. normalized
            # coordinates) keep their full extent; ``//`` would shrink them.
            half_w, half_h = w / 2, h / 2
            x1, y1, x2, y2 = x - half_w, y - half_h, x + half_w, y + half_h
            box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
        elif self.box_format == "xyxyxyxy":
            x1, y1, x2, y2, x3, y3, x4, y4 = box
            box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        return box