You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

102 lines
3.0 KiB
Python

# -*- coding: utf-8 -*-
# @Time : 2019/12/4 14:39
# @Author : zhoujun
import paddle
import paddle.nn as nn
class BalanceCrossEntropyLoss(nn.Layer):
    """
    Balanced binary cross entropy loss with hard negative mining.

    Every positive pixel inside ``mask`` contributes its BCE loss. Among the
    negative pixels inside ``mask``, only the ``negative_ratio * positive_count``
    pixels with the largest BCE loss are kept. The kept losses are summed and
    divided by the number of kept pixels (plus ``eps``).

    Shape (as originally documented):
        - Input: :math:`(N, 1, H, W)`
        - GT: :math:`(N, 1, H, W)`, same shape as the input
        - Mask: :math:`(N, H, W)`, same spatial shape as the input
        - Output: scalar.

    NOTE(review): ``gt * mask`` broadcasts element-wise, so gt and mask are
    presumably passed with matching shapes at the call site (e.g. all
    ``(N, H, W)``). Confirm against the caller.
    """

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        """
        Args:
            negative_ratio: maximum number of negatives kept per positive pixel.
            eps: added to the denominator to avoid division by zero.
        """
        super(BalanceCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps

    def forward(
        self,
        pred: paddle.Tensor,
        gt: paddle.Tensor,
        mask: paddle.Tensor,
        return_origin=False,
    ):
        """
        Args:
            pred: network prediction. It is fed to ``binary_cross_entropy``,
                so it is presumably a probability map in [0, 1].
            gt: binary target, same shape as ``pred``.
            mask: selects the pixels that take part in the loss. It is applied
                to both the positive (``gt``) and negative (``1 - gt``) pixels.
            return_origin: if True, also return the un-reduced per-pixel BCE.

        Returns:
            The balanced loss, or ``(balance_loss, loss)`` when
            ``return_origin`` is True, where ``loss`` is the per-pixel BCE.
        """
        positive = gt * mask
        negative = (1 - gt) * mask
        positive_count = int(positive.sum())
        # Hard negative mining: keep at most negative_ratio negatives per positive.
        negative_count = min(
            int(negative.sum()), int(positive_count * self.negative_ratio)
        )
        loss = nn.functional.binary_cross_entropy(pred, gt, reduction="none")
        positive_loss = loss * positive
        negative_loss = loss * negative
        if negative_count > 0:
            # Only the hardest (highest-loss) negatives contribute.
            negative_loss, _ = negative_loss.reshape([-1]).topk(negative_count)
            negative_sum = negative_loss.sum()
        else:
            # paddle.topk requires k >= 1. With no positives (or no negatives),
            # no negatives are selected, so skip the call instead of crashing.
            negative_sum = 0.0
        balance_loss = (positive_loss.sum() + negative_sum) / (
            positive_count + negative_count + self.eps
        )
        if return_origin:
            return balance_loss, loss
        return balance_loss
class DiceLoss(nn.Layer):
    """
    Dice loss, following https://arxiv.org/abs/1707.03237.

    Treats the prediction and the target as soft heatmaps and measures how
    different they are with an IoU-like score, restricted to the pixels
    selected by ``mask``:

        1 - 2 * sum(pred * gt * mask) / (sum(pred * mask) + sum(gt * mask) + eps)
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        # Keeps the denominator non-zero when the masked sums are all zero.
        self.eps = eps

    def forward(self, pred: paddle.Tensor, gt, mask, weights=None):
        """
        Args:
            pred: predicted heatmap. If it is 4-D, only index 0 of axis 1
                is used (for both pred and gt).
            gt: target heatmap; must end up with the same shape as ``pred``.
            mask: pixels that take part in the loss; same shape as the
                (possibly sliced) ``pred``.
            weights: optional per-pixel weights, same shape as ``mask``.
                When given, they are multiplied into ``mask``.

        Returns:
            The dice loss.
        """
        return self._compute(pred, gt, mask, weights)

    def _compute(self, pred, gt, mask, weights):
        """Validate shapes and evaluate the dice loss for ``forward``."""
        if len(pred.shape) == 4:
            # Drop axis 1 by taking its first entry (presumably the channel axis).
            pred, gt = pred[:, 0, :, :], gt[:, 0, :, :]
        assert pred.shape == gt.shape
        assert pred.shape == mask.shape
        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask
        overlap = (pred * gt * mask).sum()
        total = (pred * mask).sum() + (gt * mask).sum() + self.eps
        dice_loss = 1 - 2.0 * overlap / total
        # Sanity check; always holds when pred, gt and mask are non-negative.
        assert dice_loss <= 1
        return dice_loss
class MaskL1Loss(nn.Layer):
    """
    L1 (absolute error) loss averaged over the pixels selected by a mask.

    Each ``|pred - gt|`` is multiplied by ``mask``. The sum is then divided by
    the total of ``mask`` plus ``eps``, which avoids division by zero for an
    empty mask.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred: paddle.Tensor, gt, mask):
        """
        Args:
            pred: predicted values.
            gt: target values, combined element-wise with ``pred``.
            mask: per-element weights applied to the absolute errors.

        Returns:
            The masked mean absolute error.
        """
        abs_error = paddle.abs(pred - gt)
        weighted_sum = (abs_error * mask).sum()
        mask_total = mask.sum() + self.eps
        return weighted_sum / mask_total