# -*- coding: utf-8 -*-
# @Time    : 2019/12/4 14:39
# @Author  : zhoujun
import paddle
import paddle.nn as nn


class BalanceCrossEntropyLoss(nn.Layer):
    """
    Balanced cross entropy loss.

    The element-wise binary cross entropy is averaged over every masked
    positive element plus the hardest masked negative elements; at most
    ``negative_ratio`` negatives are kept per positive.

    Shape:
        - Input: :math:`(N, 1, H, W)`
        - GT: :math:`(N, 1, H, W)`, same shape as the input
        - Mask: :math:`(N, H, W)`, or the same shape as the input
        - Output: scalar.

    """

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        """
        Args:
            negative_ratio: maximum number of negatives kept, expressed as a
                multiple of the number of positives
            eps: constant added to the denominator so that an empty selection
                does not divide by zero
        """
        super(BalanceCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps

    def forward(
        self,
        pred: paddle.Tensor,
        gt: paddle.Tensor,
        mask: paddle.Tensor,
        return_origin=False,
    ):
        """
        Args:
            pred: shape :math:`(N, 1, H, W)`, the prediction of network; it is
                fed to ``binary_cross_entropy`` directly
            gt: shape :math:`(N, 1, H, W)`, the target
            mask: shape :math:`(N, H, W)` or the same shape as ``gt``; it is
                multiplied into both the positive and the negative selection,
                so only elements where it is non-zero contribute
            return_origin: when True, also return the unreduced
                element-wise cross entropy

        Returns:
            The balanced loss, or ``(balanced loss, element-wise loss)`` when
            ``return_origin`` is True.
        """
        # A (N, H, W) mask multiplied with a (N, 1, H, W) target would
        # broadcast to (N, N, H, W) and mix samples; give the mask an explicit
        # channel axis so it lines up element-wise with gt and pred.
        if len(gt.shape) == 4 and len(mask.shape) == 3 and gt.shape[1] == 1:
            mask = paddle.unsqueeze(mask, axis=1)

        positive = gt * mask
        negative = (1 - gt) * mask
        positive_count = int(positive.sum())
        # Keep at most negative_ratio negatives per positive.
        negative_count = min(
            int(negative.sum()), int(positive_count * self.negative_ratio)
        )
        loss = nn.functional.binary_cross_entropy(pred, gt, reduction="none")
        loss_sum = (loss * positive).sum()

        # topk requires k >= 1, so hard-negative mining is skipped when there
        # is nothing to select (no positives or no masked negatives).
        if negative_count > 0:
            negative_loss = (loss * negative).reshape([-1])
            hardest_negative_loss, _ = negative_loss.topk(negative_count)
            loss_sum = loss_sum + hardest_negative_loss.sum()

        balance_loss = loss_sum / (positive_count + negative_count + self.eps)

        if return_origin:
            return balance_loss, loss
        return balance_loss


class DiceLoss(nn.Layer):
    """
    Dice loss, see https://arxiv.org/abs/1707.03237.

    An IoU-like overlap ratio is computed directly on heatmaps to measure
    how much two heatmaps differ.
    """

    def __init__(self, eps=1e-6):
        """
        Args:
            eps: constant added to the denominator to avoid division by zero
        """
        super().__init__()
        self.eps = eps

    def forward(self, pred: paddle.Tensor, gt, mask, weights=None):
        """
        Args:
            pred: predicted heatmap; when it is 4-D only channel 0 is used
            gt: target heatmap, sliced the same way as ``pred``
            mask: must match the shape of ``pred`` after the channel slice
            weights: optional tensor with the same shape as ``mask``; when
                given it is multiplied into ``mask``

        Returns:
            The dice loss ``1 - 2 * intersection / union``.
        """
        return self._compute(pred, gt, mask, weights)

    def _compute(self, pred, gt, mask, weights):
        """Slice off the channel axis, validate shapes and evaluate the loss."""
        has_channel_axis = len(pred.shape) == 4
        if has_channel_axis:
            # Keep only the first channel of both heatmaps.
            pred, gt = pred[:, 0, :, :], gt[:, 0, :, :]

        assert pred.shape == gt.shape
        assert pred.shape == mask.shape

        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask

        overlap = (pred * gt * mask).sum()
        pred_area = (pred * mask).sum()
        gt_area = (gt * mask).sum()
        # eps keeps the ratio defined when both masked heatmaps are empty.
        denominator = pred_area + gt_area + self.eps

        dice_loss = 1 - 2.0 * overlap / denominator
        assert dice_loss <= 1
        return dice_loss


class MaskL1Loss(nn.Layer):
    """L1 loss weighted by a mask and normalised by the mask sum."""

    def __init__(self, eps=1e-6):
        """
        Args:
            eps: constant added to the mask sum to avoid division by zero
        """
        super().__init__()
        self.eps = eps

    def forward(self, pred: paddle.Tensor, gt, mask):
        """
        Args:
            pred: the prediction of network
            gt: the target, subtracted element-wise from ``pred``
            mask: per-element weight applied to the absolute error

        Returns:
            ``sum(|pred - gt| * mask) / (sum(mask) + eps)``.
        """
        masked_abs_error = paddle.abs(pred - gt) * mask
        normalizer = mask.sum() + self.eps
        return masked_abs_error.sum() / normalizer